Ben Clayton 0a3cda9911 tint: Replace all ProgramBuilder float literals with '_f' suffix
Unsuffixed float literals are currently treated as f32,
but will shortly become AbstractFloat. To keep tests behaving
identically to how they are currently, change all float literals
to explicitly use the f32 '_f' suffix.

Bug: tint:1504
Change-Id: I2a00725ee1b34a6efbe15ac4ba438c00c4416dd8
Reviewed-on: https://dawn-review.googlesource.com/c/dawn/+/89402
Reviewed-by: Dan Sinclair <dsinclair@chromium.org>
Commit-Queue: Ben Clayton <bclayton@google.com>
Kokoro: Kokoro <noreply+kokoro@google.com>
2022-05-10 17:30:15 +00:00

// Copyright 2020 The Tint Authors.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include "src/tint/reader/spirv/function.h"
#include <algorithm>
#include <array>
#include "src/tint/ast/assignment_statement.h"
#include "src/tint/ast/bitcast_expression.h"
#include "src/tint/ast/break_statement.h"
#include "src/tint/ast/builtin.h"
#include "src/tint/ast/builtin_attribute.h"
#include "src/tint/ast/call_statement.h"
#include "src/tint/ast/continue_statement.h"
#include "src/tint/ast/discard_statement.h"
#include "src/tint/ast/fallthrough_statement.h"
#include "src/tint/ast/if_statement.h"
#include "src/tint/ast/loop_statement.h"
#include "src/tint/ast/return_statement.h"
#include "src/tint/ast/stage_attribute.h"
#include "src/tint/ast/switch_statement.h"
#include "src/tint/ast/unary_op_expression.h"
#include "src/tint/ast/variable_decl_statement.h"
#include "src/tint/sem/builtin_type.h"
#include "src/tint/sem/depth_texture.h"
#include "src/tint/sem/sampled_texture.h"
// Terms:
// CFG: the control flow graph of the function, where basic blocks are the
// nodes, and branches form the directed arcs. The function entry block is
// the root of the CFG.
//
// Suppose H is a header block (i.e. has an OpSelectionMerge or OpLoopMerge).
// Then:
// - Let M(H) be the merge block named by the merge instruction in H.
// - If H is a loop header, i.e. has an OpLoopMerge instruction, then let
// CT(H) be the continue target block named by the OpLoopMerge
// instruction.
// - If H is a selection construct whose header ends in
// OpBranchConditional with true target %then and false target %else,
// then TT(H) = %then and FT(H) = %else
//
// Determining output block order:
// The "structured post-order traversal" of the CFG is a post-order traversal
// of the basic blocks in the CFG, where:
// We visit the entry node of the function first.
// When visiting a header block:
// We next visit its merge block
// Then if it's a loop header, we next visit the continue target,
// Then we visit the block's successors (whether it's a header or not)
// If the block ends in an OpBranchConditional, we visit the false target
// before the true target.
//
// The "reverse structured post-order traversal" of the CFG is the reverse
// of the structured post-order traversal.
// This is the order of basic blocks as they should be emitted to the WGSL
// function. It is the order computed by ComputeBlockOrder, and stored in
// the |FunctionEmitter::block_order_|.
// Blocks not in this ordering are ignored by the rest of the algorithm.
//
// Note:
// - A block D in the function might not appear in this order because
// no block in the order branches to D.
// - An unreachable block D might still be in the order because some header
// block in the order names D as its continue target, or merge block,
// or D is reachable from one of those otherwise-unreachable continue
// targets or merge blocks.
//
// Terms:
// Let Pos(B) be the index position of a block B in the computed block order.
//
// CFG intervals and valid nesting:
//
// A correctly structured CFG satisfies nesting rules that we can check by
// comparing positions of related blocks.
//
// If header block H is in the block order, then the following holds:
//
// Pos(H) < Pos(M(H))
//
// If CT(H) exists, then:
//
// Pos(H) <= Pos(CT(H))
// Pos(CT(H)) < Pos(M(H))
//
// This gives us the fundamental ordering of blocks in relation to a
// structured construct:
// The blocks before H in the block order, are not in the construct
// The blocks at M(H) or later in the block order, are not in the construct
// The blocks in a selection construct headed at H are in positions
// [ Pos(H), Pos(M(H)) )
// The blocks in a loop construct headed at H are in positions
// [ Pos(H), Pos(CT(H)) )
// The blocks in the continue construct for the loop headed at H are in
// positions [ Pos(CT(H)), Pos(M(H)) )
//
// Schematically, for a selection construct headed by H, the blocks are in
// order from left to right:
//
// ...a-b-c H d-e-f M(H) n-o-p...
//
// where ...a-b-c: blocks before the selection construct
// where H and d-e-f: blocks in the selection construct
// where M(H) and n-o-p...: blocks after the selection construct
//
// Schematically, for a loop construct headed by H that is its own
// continue construct, the blocks in order from left to right:
//
// ...a-b-c H=CT(H) d-e-f M(H) n-o-p...
//
// where ...a-b-c: blocks before the loop
// where H is the continue construct; CT(H)=H, and the loop construct
// is *empty*
// where d-e-f... are other blocks in the continue construct
// where M(H) and n-o-p...: blocks after the continue construct
//
// Schematically, for a multi-block loop construct headed by H, there are
// blocks in order from left to right:
//
// ...a-b-c H d-e-f CT(H) j-k-l M(H) n-o-p...
//
// where ...a-b-c: blocks before the loop
// where H and d-e-f: blocks in the loop construct
// where CT(H) and j-k-l: blocks in the continue construct
// where M(H) and n-o-p...: blocks after the loop and continue
// constructs
//
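//
// Worked example (illustrative only; all IDs below are hypothetical):
// suppose entry block %10 heads an if-selection:
//
//   %10: OpSelectionMerge %40 None
//        OpBranchConditional %cond %20 %30
//   %20: OpBranch %40
//   %30: OpBranch %40
//   %40: OpReturn
//
// Visiting %10, then its merge %40, then the false target %30 before the
// true target %20 yields post-order %40 %30 %20 %10, and hence reverse
// structured post-order %10 %20 %30 %40: the top-to-bottom order in which
// the WGSL "if" should be emitted.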
using namespace tint::number_suffixes; // NOLINT
namespace tint::reader::spirv {
namespace {
constexpr uint32_t kMaxVectorLen = 4;
// Gets the AST unary opcode for the given SPIR-V opcode, if any
// @param opcode SPIR-V opcode
// @param ast_unary_op return parameter
// @returns true if it was a unary operation
bool GetUnaryOp(SpvOp opcode, ast::UnaryOp* ast_unary_op) {
switch (opcode) {
case SpvOpSNegate:
case SpvOpFNegate:
*ast_unary_op = ast::UnaryOp::kNegation;
return true;
case SpvOpLogicalNot:
*ast_unary_op = ast::UnaryOp::kNot;
return true;
case SpvOpNot:
*ast_unary_op = ast::UnaryOp::kComplement;
return true;
default:
break;
}
return false;
}
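// Usage sketch (hypothetical caller; not part of the emitter flow):
//
//   ast::UnaryOp op;
//   if (GetUnaryOp(SpvOpFNegate, &op)) {
//     // op == ast::UnaryOp::kNegation, so `%r = OpFNegate %f %x`
//     // becomes the WGSL expression `-x`.
//   }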
/// Returns the name of the WGSL builtin function that is a direct
/// translation of the given SPIR-V opcode, if there is one.
/// @returns the WGSL builtin function name for the given opcode, or nullptr.
const char* GetUnaryBuiltInFunctionName(SpvOp opcode) {
switch (opcode) {
case SpvOpAny:
return "any";
case SpvOpAll:
return "all";
case SpvOpIsNan:
return "isNan";
case SpvOpIsInf:
return "isInf";
case SpvOpTranspose:
return "transpose";
default:
break;
}
return nullptr;
}
// Converts a SPIR-V opcode to its corresponding AST binary opcode, if any
// @param opcode SPIR-V opcode
// @returns the AST binary op for the given opcode, or kNone
ast::BinaryOp ConvertBinaryOp(SpvOp opcode) {
switch (opcode) {
case SpvOpIAdd:
case SpvOpFAdd:
return ast::BinaryOp::kAdd;
case SpvOpISub:
case SpvOpFSub:
return ast::BinaryOp::kSubtract;
case SpvOpIMul:
case SpvOpFMul:
case SpvOpVectorTimesScalar:
case SpvOpMatrixTimesScalar:
case SpvOpVectorTimesMatrix:
case SpvOpMatrixTimesVector:
case SpvOpMatrixTimesMatrix:
return ast::BinaryOp::kMultiply;
case SpvOpUDiv:
case SpvOpSDiv:
case SpvOpFDiv:
return ast::BinaryOp::kDivide;
case SpvOpUMod:
case SpvOpSMod:
case SpvOpFRem:
return ast::BinaryOp::kModulo;
case SpvOpLogicalEqual:
case SpvOpIEqual:
case SpvOpFOrdEqual:
return ast::BinaryOp::kEqual;
case SpvOpLogicalNotEqual:
case SpvOpINotEqual:
case SpvOpFOrdNotEqual:
return ast::BinaryOp::kNotEqual;
case SpvOpBitwiseAnd:
return ast::BinaryOp::kAnd;
case SpvOpBitwiseOr:
return ast::BinaryOp::kOr;
case SpvOpBitwiseXor:
return ast::BinaryOp::kXor;
case SpvOpLogicalAnd:
return ast::BinaryOp::kAnd;
case SpvOpLogicalOr:
return ast::BinaryOp::kOr;
case SpvOpUGreaterThan:
case SpvOpSGreaterThan:
case SpvOpFOrdGreaterThan:
return ast::BinaryOp::kGreaterThan;
case SpvOpUGreaterThanEqual:
case SpvOpSGreaterThanEqual:
case SpvOpFOrdGreaterThanEqual:
return ast::BinaryOp::kGreaterThanEqual;
case SpvOpULessThan:
case SpvOpSLessThan:
case SpvOpFOrdLessThan:
return ast::BinaryOp::kLessThan;
case SpvOpULessThanEqual:
case SpvOpSLessThanEqual:
case SpvOpFOrdLessThanEqual:
return ast::BinaryOp::kLessThanEqual;
default:
break;
}
// It's not clear what OpSMod should map to.
// https://bugs.chromium.org/p/tint/issues/detail?id=52
return ast::BinaryOp::kNone;
}
// If the given SPIR-V opcode is a floating point unordered comparison,
// then returns the binary float comparison for which it is the negation.
// Otherwise returns BinaryOp::kNone.
// @param opcode SPIR-V opcode
// @returns operation corresponding to negated version of the SPIR-V opcode
ast::BinaryOp NegatedFloatCompare(SpvOp opcode) {
switch (opcode) {
case SpvOpFUnordEqual:
return ast::BinaryOp::kNotEqual;
case SpvOpFUnordNotEqual:
return ast::BinaryOp::kEqual;
case SpvOpFUnordLessThan:
return ast::BinaryOp::kGreaterThanEqual;
case SpvOpFUnordLessThanEqual:
return ast::BinaryOp::kGreaterThan;
case SpvOpFUnordGreaterThan:
return ast::BinaryOp::kLessThanEqual;
case SpvOpFUnordGreaterThanEqual:
return ast::BinaryOp::kLessThan;
default:
break;
}
return ast::BinaryOp::kNone;
}
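// Example rewrite (illustrative): WGSL has only ordered float comparisons,
// so `%r = OpFUnordLessThan %bool %a %b` is emitted as the logical negation
// of the comparison returned here, roughly:
//
//   let r : bool = !(a >= b);  // true when a < b or either operand is NaN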
// Returns the WGSL standard library function for the given
// GLSL.std.450 extended instruction operation code. Unknown
// and invalid opcodes map to the empty string.
// @returns the WGSL standard function name, or an empty string.
std::string GetGlslStd450FuncName(uint32_t ext_opcode) {
switch (ext_opcode) {
case GLSLstd450FAbs:
case GLSLstd450SAbs:
return "abs";
case GLSLstd450Acos:
return "acos";
case GLSLstd450Asin:
return "asin";
case GLSLstd450Atan:
return "atan";
case GLSLstd450Atan2:
return "atan2";
case GLSLstd450Ceil:
return "ceil";
case GLSLstd450UClamp:
case GLSLstd450SClamp:
case GLSLstd450NClamp:
case GLSLstd450FClamp: // FClamp is less prescriptive about NaN operands
return "clamp";
case GLSLstd450Cos:
return "cos";
case GLSLstd450Cosh:
return "cosh";
case GLSLstd450Cross:
return "cross";
case GLSLstd450Degrees:
return "degrees";
case GLSLstd450Distance:
return "distance";
case GLSLstd450Exp:
return "exp";
case GLSLstd450Exp2:
return "exp2";
case GLSLstd450FaceForward:
return "faceForward";
case GLSLstd450Floor:
return "floor";
case GLSLstd450Fma:
return "fma";
case GLSLstd450Fract:
return "fract";
case GLSLstd450InverseSqrt:
return "inverseSqrt";
case GLSLstd450Ldexp:
return "ldexp";
case GLSLstd450Length:
return "length";
case GLSLstd450Log:
return "log";
case GLSLstd450Log2:
return "log2";
case GLSLstd450NMax:
case GLSLstd450FMax: // FMax is less prescriptive about NaN operands
case GLSLstd450UMax:
case GLSLstd450SMax:
return "max";
case GLSLstd450NMin:
case GLSLstd450FMin: // FMin is less prescriptive about NaN operands
case GLSLstd450UMin:
case GLSLstd450SMin:
return "min";
case GLSLstd450FMix:
return "mix";
case GLSLstd450Normalize:
return "normalize";
case GLSLstd450PackSnorm4x8:
return "pack4x8snorm";
case GLSLstd450PackUnorm4x8:
return "pack4x8unorm";
case GLSLstd450PackSnorm2x16:
return "pack2x16snorm";
case GLSLstd450PackUnorm2x16:
return "pack2x16unorm";
case GLSLstd450PackHalf2x16:
return "pack2x16float";
case GLSLstd450Pow:
return "pow";
case GLSLstd450FSign:
return "sign";
case GLSLstd450Radians:
return "radians";
case GLSLstd450Reflect:
return "reflect";
case GLSLstd450Refract:
return "refract";
case GLSLstd450Round:
case GLSLstd450RoundEven:
return "round";
case GLSLstd450Sin:
return "sin";
case GLSLstd450Sinh:
return "sinh";
case GLSLstd450SmoothStep:
return "smoothstep";
case GLSLstd450Sqrt:
return "sqrt";
case GLSLstd450Step:
return "step";
case GLSLstd450Tan:
return "tan";
case GLSLstd450Tanh:
return "tanh";
case GLSLstd450Trunc:
return "trunc";
case GLSLstd450UnpackSnorm4x8:
return "unpack4x8snorm";
case GLSLstd450UnpackUnorm4x8:
return "unpack4x8unorm";
case GLSLstd450UnpackSnorm2x16:
return "unpack2x16snorm";
case GLSLstd450UnpackUnorm2x16:
return "unpack2x16unorm";
case GLSLstd450UnpackHalf2x16:
return "unpack2x16float";
default:
// TODO(dneto) - The following are not implemented.
// They are grouped semantically, as in GLSL.std.450.h.
case GLSLstd450SSign:
case GLSLstd450Asinh:
case GLSLstd450Acosh:
case GLSLstd450Atanh:
case GLSLstd450Determinant:
case GLSLstd450MatrixInverse:
case GLSLstd450Modf:
case GLSLstd450ModfStruct:
case GLSLstd450IMix:
case GLSLstd450Frexp:
case GLSLstd450FrexpStruct:
case GLSLstd450PackDouble2x32:
case GLSLstd450UnpackDouble2x32:
case GLSLstd450FindILsb:
case GLSLstd450FindSMsb:
case GLSLstd450FindUMsb:
case GLSLstd450InterpolateAtCentroid:
case GLSLstd450InterpolateAtSample:
case GLSLstd450InterpolateAtOffset:
break;
}
return "";
}
// Returns the WGSL standard library function builtin for the
// given instruction, or sem::BuiltinType::kNone
sem::BuiltinType GetBuiltin(SpvOp opcode) {
switch (opcode) {
case SpvOpBitCount:
return sem::BuiltinType::kCountOneBits;
case SpvOpBitFieldInsert:
return sem::BuiltinType::kInsertBits;
case SpvOpBitFieldSExtract:
case SpvOpBitFieldUExtract:
return sem::BuiltinType::kExtractBits;
case SpvOpBitReverse:
return sem::BuiltinType::kReverseBits;
case SpvOpDot:
return sem::BuiltinType::kDot;
case SpvOpDPdx:
return sem::BuiltinType::kDpdx;
case SpvOpDPdy:
return sem::BuiltinType::kDpdy;
case SpvOpFwidth:
return sem::BuiltinType::kFwidth;
case SpvOpDPdxFine:
return sem::BuiltinType::kDpdxFine;
case SpvOpDPdyFine:
return sem::BuiltinType::kDpdyFine;
case SpvOpFwidthFine:
return sem::BuiltinType::kFwidthFine;
case SpvOpDPdxCoarse:
return sem::BuiltinType::kDpdxCoarse;
case SpvOpDPdyCoarse:
return sem::BuiltinType::kDpdyCoarse;
case SpvOpFwidthCoarse:
return sem::BuiltinType::kFwidthCoarse;
default:
break;
}
return sem::BuiltinType::kNone;
}
// @param opcode a SPIR-V opcode
// @returns true if the given instruction is an image access instruction
// whose first input operand is an OpSampledImage value.
bool IsSampledImageAccess(SpvOp opcode) {
switch (opcode) {
case SpvOpImageSampleImplicitLod:
case SpvOpImageSampleExplicitLod:
case SpvOpImageSampleDrefImplicitLod:
case SpvOpImageSampleDrefExplicitLod:
// WGSL doesn't have *Proj* texturing; spirv reader emulates it.
case SpvOpImageSampleProjImplicitLod:
case SpvOpImageSampleProjExplicitLod:
case SpvOpImageSampleProjDrefImplicitLod:
case SpvOpImageSampleProjDrefExplicitLod:
case SpvOpImageGather:
case SpvOpImageDrefGather:
case SpvOpImageQueryLod:
return true;
default:
break;
}
return false;
}
// @param opcode a SPIR-V opcode
// @returns true if the given instruction is an image sampling, gather,
// or gather-compare operation.
bool IsImageSamplingOrGatherOrDrefGather(SpvOp opcode) {
switch (opcode) {
case SpvOpImageSampleImplicitLod:
case SpvOpImageSampleExplicitLod:
case SpvOpImageSampleDrefImplicitLod:
case SpvOpImageSampleDrefExplicitLod:
// WGSL doesn't have *Proj* texturing; spirv reader emulates it.
case SpvOpImageSampleProjImplicitLod:
case SpvOpImageSampleProjExplicitLod:
case SpvOpImageSampleProjDrefImplicitLod:
case SpvOpImageSampleProjDrefExplicitLod:
case SpvOpImageGather:
case SpvOpImageDrefGather:
return true;
default:
break;
}
return false;
}
// @param opcode a SPIR-V opcode
// @returns true if the given instruction is an image access instruction
// whose first input operand is an OpImage value.
bool IsRawImageAccess(SpvOp opcode) {
switch (opcode) {
case SpvOpImageRead:
case SpvOpImageWrite:
case SpvOpImageFetch:
return true;
default:
break;
}
return false;
}
// @param opcode a SPIR-V opcode
// @returns true if the given instruction is an image query instruction
bool IsImageQuery(SpvOp opcode) {
switch (opcode) {
case SpvOpImageQuerySize:
case SpvOpImageQuerySizeLod:
case SpvOpImageQueryLevels:
case SpvOpImageQuerySamples:
case SpvOpImageQueryLod:
return true;
default:
break;
}
return false;
}
// @returns the merge block ID for the given basic block, or 0 if there is none.
uint32_t MergeFor(const spvtools::opt::BasicBlock& bb) {
// Get the OpSelectionMerge or OpLoopMerge instruction, if any.
auto* inst = bb.GetMergeInst();
return inst == nullptr ? 0 : inst->GetSingleWordInOperand(0);
}
// @returns the continue target ID for the given basic block, or 0 if there
// is none.
uint32_t ContinueTargetFor(const spvtools::opt::BasicBlock& bb) {
// Get the OpLoopMerge instruction, if any.
auto* inst = bb.GetLoopMergeInst();
return inst == nullptr ? 0 : inst->GetSingleWordInOperand(1);
}
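// For reference, the operand layout assumed by the two helpers above is the
// one from the SPIR-V specification:
//
//   OpSelectionMerge %merge <SelectionControl>    ; in-operand 0: %merge
//   OpLoopMerge %merge %continue <LoopControl>    ; in-operand 0: %merge
//                                                 ; in-operand 1: %continue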
// A structured traverser produces the reverse structured post-order of the
// CFG of a function. The blocks traversed are the transitive closure (minimum
// fixed point) of:
// - the entry block
// - a block reached by a branch from another block in the set
// - a block mentioned as a merge block or continue target for a block in the
// set
class StructuredTraverser {
public:
explicit StructuredTraverser(const spvtools::opt::Function& function) : function_(function) {
for (auto& block : function_) {
id_to_block_[block.id()] = &block;
}
}
// Returns the reverse postorder traversal of the CFG, where:
// - a merge block always follows its associated constructs
// - a continue target always follows the associated loop construct, if any
// @returns the IDs of blocks in reverse structured post order
std::vector<uint32_t> ReverseStructuredPostOrder() {
visit_order_.clear();
visited_.clear();
VisitBackward(function_.entry()->id());
std::vector<uint32_t> order(visit_order_.rbegin(), visit_order_.rend());
return order;
}
private:
// Executes a depth first search of the CFG, where right after we visit a
// header, we will visit its merge block, then its continue target (if any).
// Also records the post order ordering.
void VisitBackward(uint32_t id) {
if (id == 0)
return;
if (visited_.count(id))
return;
visited_.insert(id);
const spvtools::opt::BasicBlock* bb = id_to_block_[id]; // non-null for valid modules
VisitBackward(MergeFor(*bb));
VisitBackward(ContinueTargetFor(*bb));
// Visit successors. We will naturally skip the continue target and merge
// blocks.
auto* terminator = bb->terminator();
auto opcode = terminator->opcode();
if (opcode == SpvOpBranchConditional) {
// Visit the false branch, then the true branch, to make them come
// out in the natural order for an "if".
VisitBackward(terminator->GetSingleWordInOperand(2));
VisitBackward(terminator->GetSingleWordInOperand(1));
} else if (opcode == SpvOpBranch) {
VisitBackward(terminator->GetSingleWordInOperand(0));
} else if (opcode == SpvOpSwitch) {
// TODO(dneto): Consider visiting the labels in literal-value order.
std::vector<uint32_t> successors;
bb->ForEachSuccessorLabel(
[&successors](const uint32_t succ_id) { successors.push_back(succ_id); });
for (auto succ_id : successors) {
VisitBackward(succ_id);
}
}
visit_order_.push_back(id);
}
const spvtools::opt::Function& function_;
std::unordered_map<uint32_t, const spvtools::opt::BasicBlock*> id_to_block_;
std::vector<uint32_t> visit_order_;
std::unordered_set<uint32_t> visited_;
};
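// Usage sketch: the emitter drives the traverser in a single call, e.g.
//
//   std::vector<uint32_t> order =
//       StructuredTraverser(function).ReverseStructuredPostOrder();
//
// which is exactly what ComputeBlockOrderAndPositions() does below before
// recording each block's position in the order.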
/// A StatementBuilder for ast::SwitchStatement
/// @see StatementBuilder
struct SwitchStatementBuilder final : public Castable<SwitchStatementBuilder, StatementBuilder> {
/// Constructor
/// @param cond the switch statement condition
explicit SwitchStatementBuilder(const ast::Expression* cond) : condition(cond) {}
/// @param builder the program builder
/// @returns the built ast::SwitchStatement
const ast::SwitchStatement* Build(ProgramBuilder* builder) const override {
// We've listed cases in reverse order in the switch statement.
// Reorder them to match the presentation order in WGSL.
auto reversed_cases = cases;
std::reverse(reversed_cases.begin(), reversed_cases.end());
return builder->create<ast::SwitchStatement>(Source{}, condition, reversed_cases);
}
/// Switch statement condition
const ast::Expression* const condition;
/// Switch statement cases
ast::CaseStatementList cases;
};
/// A StatementBuilder for ast::IfStatement
/// @see StatementBuilder
struct IfStatementBuilder final : public Castable<IfStatementBuilder, StatementBuilder> {
/// Constructor
/// @param c the if-statement condition
explicit IfStatementBuilder(const ast::Expression* c) : cond(c) {}
/// @param builder the program builder
/// @returns the built ast::IfStatement
const ast::IfStatement* Build(ProgramBuilder* builder) const override {
return builder->create<ast::IfStatement>(Source{}, cond, body, else_stmt);
}
/// If-statement condition
const ast::Expression* const cond;
/// If-statement block body
const ast::BlockStatement* body = nullptr;
/// Optional if-statement else statement
const ast::Statement* else_stmt = nullptr;
};
/// A StatementBuilder for ast::LoopStatement
/// @see StatementBuilder
struct LoopStatementBuilder final : public Castable<LoopStatementBuilder, StatementBuilder> {
/// @param builder the program builder
/// @returns the built ast::LoopStatement
ast::LoopStatement* Build(ProgramBuilder* builder) const override {
return builder->create<ast::LoopStatement>(Source{}, body, continuing);
}
/// Loop-statement block body
const ast::BlockStatement* body = nullptr;
/// Loop-statement continuing body
/// @note the mutable keyword here is required as all non-StatementBuilders
/// `ast::Node`s are immutable and are referenced with `const` pointers.
/// StatementBuilders however exist to provide mutable state while the
/// FunctionEmitter is building the function. All StatementBuilders are
/// replaced with immutable AST nodes when Finalize() is called.
mutable const ast::BlockStatement* continuing = nullptr;
};
/// @param decos a list of parsed decorations
/// @returns true if the decorations include a SampleMask builtin
bool HasBuiltinSampleMask(const ast::AttributeList& decos) {
if (auto* builtin = ast::GetAttribute<ast::BuiltinAttribute>(decos)) {
return builtin->builtin == ast::Builtin::kSampleMask;
}
return false;
}
} // namespace
BlockInfo::BlockInfo(const spvtools::opt::BasicBlock& bb) : basic_block(&bb), id(bb.id()) {}
BlockInfo::~BlockInfo() = default;
DefInfo::DefInfo(const spvtools::opt::Instruction& def_inst,
uint32_t the_block_pos,
size_t the_index)
: inst(def_inst), block_pos(the_block_pos), index(the_index) {}
DefInfo::~DefInfo() = default;
ast::Node* StatementBuilder::Clone(CloneContext*) const {
return nullptr;
}
FunctionEmitter::FunctionEmitter(ParserImpl* pi,
const spvtools::opt::Function& function,
const EntryPointInfo* ep_info)
: parser_impl_(*pi),
ty_(pi->type_manager()),
builder_(pi->builder()),
ir_context_(*(pi->ir_context())),
def_use_mgr_(ir_context_.get_def_use_mgr()),
constant_mgr_(ir_context_.get_constant_mgr()),
type_mgr_(ir_context_.get_type_mgr()),
fail_stream_(pi->fail_stream()),
namer_(pi->namer()),
function_(function),
sample_mask_in_id(0u),
sample_mask_out_id(0u),
ep_info_(ep_info) {
PushNewStatementBlock(nullptr, 0, nullptr);
}
FunctionEmitter::FunctionEmitter(ParserImpl* pi, const spvtools::opt::Function& function)
: FunctionEmitter(pi, function, nullptr) {}
FunctionEmitter::FunctionEmitter(FunctionEmitter&& other)
: parser_impl_(other.parser_impl_),
ty_(other.ty_),
builder_(other.builder_),
ir_context_(other.ir_context_),
def_use_mgr_(ir_context_.get_def_use_mgr()),
constant_mgr_(ir_context_.get_constant_mgr()),
type_mgr_(ir_context_.get_type_mgr()),
fail_stream_(other.fail_stream_),
namer_(other.namer_),
function_(other.function_),
sample_mask_in_id(other.sample_mask_in_id),
sample_mask_out_id(other.sample_mask_out_id),
ep_info_(other.ep_info_) {
other.statements_stack_.clear();
PushNewStatementBlock(nullptr, 0, nullptr);
}
FunctionEmitter::~FunctionEmitter() = default;
FunctionEmitter::StatementBlock::StatementBlock(const Construct* construct,
uint32_t end_id,
FunctionEmitter::CompletionAction completion_action)
: construct_(construct), end_id_(end_id), completion_action_(completion_action) {}
FunctionEmitter::StatementBlock::StatementBlock(StatementBlock&& other) = default;
FunctionEmitter::StatementBlock::~StatementBlock() = default;
void FunctionEmitter::StatementBlock::Finalize(ProgramBuilder* pb) {
TINT_ASSERT(Reader, !finalized_ /* Finalize() must only be called once */);
for (size_t i = 0; i < statements_.size(); i++) {
if (auto* sb = statements_[i]->As<StatementBuilder>()) {
statements_[i] = sb->Build(pb);
}
}
if (completion_action_ != nullptr) {
completion_action_(statements_);
}
finalized_ = true;
}
void FunctionEmitter::StatementBlock::Add(const ast::Statement* statement) {
TINT_ASSERT(Reader, !finalized_ /* Add() must not be called after Finalize() */);
statements_.emplace_back(statement);
}
void FunctionEmitter::PushNewStatementBlock(const Construct* construct,
uint32_t end_id,
CompletionAction action) {
statements_stack_.emplace_back(StatementBlock{construct, end_id, action});
}
void FunctionEmitter::PushGuard(const std::string& guard_name, uint32_t end_id) {
TINT_ASSERT(Reader, !statements_stack_.empty());
TINT_ASSERT(Reader, !guard_name.empty());
// Guard control flow by the guard variable. Introduce a new
// if-selection with a then-clause ending at the same block
// as the statement block at the top of the stack.
const auto& top = statements_stack_.back();
auto* cond =
create<ast::IdentifierExpression>(Source{}, builder_.Symbols().Register(guard_name));
auto* builder = AddStatementBuilder<IfStatementBuilder>(cond);
PushNewStatementBlock(top.GetConstruct(), end_id, [=](const ast::StatementList& stmts) {
builder->body = create<ast::BlockStatement>(Source{}, stmts);
});
}
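// For illustration (the guard name below is hypothetical; real names come
// from the Namer): statements added while this guard is on the stack end up
// wrapped in the emitted WGSL as
//
//   if (guard101) {
//     // ...statements added to the pushed statement block...
//   }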
void FunctionEmitter::PushTrueGuard(uint32_t end_id) {
TINT_ASSERT(Reader, !statements_stack_.empty());
const auto& top = statements_stack_.back();
auto* cond = MakeTrue(Source{});
auto* builder = AddStatementBuilder<IfStatementBuilder>(cond);
PushNewStatementBlock(top.GetConstruct(), end_id, [=](const ast::StatementList& stmts) {
builder->body = create<ast::BlockStatement>(Source{}, stmts);
});
}
const ast::StatementList FunctionEmitter::ast_body() {
TINT_ASSERT(Reader, !statements_stack_.empty());
auto& entry = statements_stack_[0];
entry.Finalize(&builder_);
return entry.GetStatements();
}
const ast::Statement* FunctionEmitter::AddStatement(const ast::Statement* statement) {
TINT_ASSERT(Reader, !statements_stack_.empty());
if (statement != nullptr) {
statements_stack_.back().Add(statement);
}
return statement;
}
const ast::Statement* FunctionEmitter::LastStatement() {
TINT_ASSERT(Reader, !statements_stack_.empty());
auto& statement_list = statements_stack_.back().GetStatements();
TINT_ASSERT(Reader, !statement_list.empty());
return statement_list.back();
}
bool FunctionEmitter::Emit() {
if (failed()) {
return false;
}
// We only care about functions with bodies.
if (function_.cbegin() == function_.cend()) {
return true;
}
// The function declaration, corresponding to how it's written in SPIR-V,
// and without regard to whether it's an entry point.
FunctionDeclaration decl;
if (!ParseFunctionDeclaration(&decl)) {
return false;
}
bool make_body_function = true;
if (ep_info_) {
TINT_ASSERT(Reader, !ep_info_->inner_name.empty());
if (ep_info_->owns_inner_implementation) {
// This is an entry point, and we want to emit it as a wrapper around
// an implementation function.
decl.name = ep_info_->inner_name;
} else {
// This is a second entry point that shares an inner implementation
// function.
make_body_function = false;
}
}
if (make_body_function) {
auto* body = MakeFunctionBody();
if (!body) {
return false;
}
builder_.AST().AddFunction(
create<ast::Function>(decl.source, builder_.Symbols().Register(decl.name),
std::move(decl.params), decl.return_type->Build(builder_), body,
std::move(decl.attributes), ast::AttributeList{}));
}
if (ep_info_ && !ep_info_->inner_name.empty()) {
return EmitEntryPointAsWrapper();
}
return success();
}
const ast::BlockStatement* FunctionEmitter::MakeFunctionBody() {
TINT_ASSERT(Reader, statements_stack_.size() == 1);
if (!EmitBody()) {
return nullptr;
}
// Set the body of the AST function node.
if (statements_stack_.size() != 1) {
Fail() << "internal error: statement-list stack should have 1 "
"element but has "
<< statements_stack_.size();
return nullptr;
}
statements_stack_[0].Finalize(&builder_);
auto& statements = statements_stack_[0].GetStatements();
auto* body = create<ast::BlockStatement>(Source{}, statements);
// Maintain the invariant by repopulating the one and only element.
statements_stack_.clear();
PushNewStatementBlock(constructs_[0].get(), 0, nullptr);
return body;
}
bool FunctionEmitter::EmitPipelineInput(std::string var_name,
const Type* var_type,
ast::AttributeList* attrs,
std::vector<int> index_prefix,
const Type* tip_type,
const Type* forced_param_type,
ast::VariableList* params,
ast::StatementList* statements) {
// TODO(dneto): Handle structs where the locations are annotated on members.
tip_type = tip_type->UnwrapAlias();
if (auto* ref_type = tip_type->As<Reference>()) {
tip_type = ref_type->type;
}
// Recursively flatten matrices, arrays, and structures.
return Switch(
tip_type,
[&](const Matrix* matrix_type) -> bool {
index_prefix.push_back(0);
const auto num_columns = static_cast<int>(matrix_type->columns);
const Type* vec_ty = ty_.Vector(matrix_type->type, matrix_type->rows);
for (int col = 0; col < num_columns; col++) {
index_prefix.back() = col;
if (!EmitPipelineInput(var_name, var_type, attrs, index_prefix, vec_ty,
forced_param_type, params, statements)) {
return false;
}
}
return success();
},
[&](const Array* array_type) -> bool {
if (array_type->size == 0) {
return Fail() << "runtime-size array not allowed on pipeline IO";
}
index_prefix.push_back(0);
const Type* elem_ty = array_type->type;
for (int i = 0; i < static_cast<int>(array_type->size); i++) {
index_prefix.back() = i;
if (!EmitPipelineInput(var_name, var_type, attrs, index_prefix, elem_ty,
forced_param_type, params, statements)) {
return false;
}
}
return success();
},
[&](const Struct* struct_type) -> bool {
const auto& members = struct_type->members;
index_prefix.push_back(0);
for (int i = 0; i < static_cast<int>(members.size()); ++i) {
index_prefix.back() = i;
ast::AttributeList member_attrs(*attrs);
if (!parser_impl_.ConvertPipelineDecorations(
struct_type, parser_impl_.GetMemberPipelineDecorations(*struct_type, i),
&member_attrs)) {
return false;
}
if (!EmitPipelineInput(var_name, var_type, &member_attrs, index_prefix, members[i],
forced_param_type, params, statements)) {
return false;
}
// Copy the location as updated by nested expansion of the member.
parser_impl_.SetLocation(attrs, GetLocation(member_attrs));
}
return success();
},
[&](Default) {
const bool is_builtin = ast::HasAttribute<ast::BuiltinAttribute>(*attrs);
const Type* param_type = is_builtin ? forced_param_type : tip_type;
const auto param_name = namer_.MakeDerivedName(var_name + "_param");
// Create the parameter.
// TODO(dneto): Note: If the parameter has non-location decorations,
// then those decoration AST nodes will be reused between multiple
// elements of a matrix, array, or structure. Normally that's
// disallowed but currently the SPIR-V reader will make duplicates when
// the entire AST is cloned at the top level of the SPIR-V reader flow.
// Consider rewriting this to avoid this node-sharing.
params->push_back(builder_.Param(param_name, param_type->Build(builder_), *attrs));
// Add a body statement to copy the parameter to the corresponding
// private variable.
const ast::Expression* param_value = builder_.Expr(param_name);
const ast::Expression* store_dest = builder_.Expr(var_name);
// Index into the LHS as needed.
auto* current_type = var_type->UnwrapAlias()->UnwrapRef()->UnwrapAlias();
for (auto index : index_prefix) {
Switch(
current_type,
[&](const Matrix* matrix_type) {
store_dest = builder_.IndexAccessor(store_dest, builder_.Expr(i32(index)));
current_type = ty_.Vector(matrix_type->type, matrix_type->rows);
},
[&](const Array* array_type) {
store_dest = builder_.IndexAccessor(store_dest, builder_.Expr(i32(index)));
current_type = array_type->type->UnwrapAlias();
},
[&](const Struct* struct_type) {
store_dest = builder_.MemberAccessor(
store_dest,
builder_.Expr(parser_impl_.GetMemberName(*struct_type, index)));
current_type = struct_type->members[index];
});
}
if (is_builtin && (tip_type != forced_param_type)) {
// The parameter will have the WGSL type, but we need bitcast to
// the variable store type.
param_value =
create<ast::BitcastExpression>(tip_type->Build(builder_), param_value);
}
statements->push_back(builder_.Assign(store_dest, param_value));
// Increment the location attribute, in case more parameters will
// follow.
IncrementLocation(attrs);
return success();
});
}
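// Flattening example (a sketch; names, types and locations are
// hypothetical): a pipeline input `var<private> x : mat2x3<f32>` whose
// first location is 3 expands to two vector parameters plus two copy
// statements, roughly:
//
//   fn main(@location(3) x_param : vec3<f32>,
//           @location(4) x_param_1 : vec3<f32>) {
//     x[0] = x_param;
//     x[1] = x_param_1;
//     ...
//   }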
void FunctionEmitter::IncrementLocation(ast::AttributeList* attributes) {
for (auto*& attr : *attributes) {
if (auto* loc_attr = attr->As<ast::LocationAttribute>()) {
// Replace this location attribute with a new one with one higher index.
// The old one doesn't leak because it's kept in the builder's AST node
// list.
attr = builder_.Location(loc_attr->source, loc_attr->value + 1);
}
}
}
const ast::Attribute* FunctionEmitter::GetLocation(const ast::AttributeList& attributes) {
for (auto* const& attr : attributes) {
if (attr->Is<ast::LocationAttribute>()) {
return attr;
}
}
return nullptr;
}
bool FunctionEmitter::EmitPipelineOutput(std::string var_name,
const Type* var_type,
ast::AttributeList* decos,
std::vector<int> index_prefix,
const Type* tip_type,
const Type* forced_member_type,
ast::StructMemberList* return_members,
ast::ExpressionList* return_exprs) {
tip_type = tip_type->UnwrapAlias();
if (auto* ref_type = tip_type->As<Reference>()) {
tip_type = ref_type->type;
}
// Recursively flatten matrices, arrays, and structures.
return Switch(
tip_type,
[&](const Matrix* matrix_type) {
index_prefix.push_back(0);
const auto num_columns = static_cast<int>(matrix_type->columns);
const Type* vec_ty = ty_.Vector(matrix_type->type, matrix_type->rows);
for (int col = 0; col < num_columns; col++) {
index_prefix.back() = col;
if (!EmitPipelineOutput(var_name, var_type, decos, index_prefix, vec_ty,
forced_member_type, return_members, return_exprs)) {
return false;
}
}
return success();
},
[&](const Array* array_type) -> bool {
if (array_type->size == 0) {
return Fail() << "runtime-size array not allowed on pipeline IO";
}
index_prefix.push_back(0);
const Type* elem_ty = array_type->type;
for (int i = 0; i < static_cast<int>(array_type->size); i++) {
index_prefix.back() = i;
if (!EmitPipelineOutput(var_name, var_type, decos, index_prefix, elem_ty,
forced_member_type, return_members, return_exprs)) {
return false;
}
}
return success();
},
[&](const Struct* struct_type) -> bool {
const auto& members = struct_type->members;
index_prefix.push_back(0);
for (int i = 0; i < static_cast<int>(members.size()); ++i) {
index_prefix.back() = i;
ast::AttributeList member_attrs(*decos);
if (!parser_impl_.ConvertPipelineDecorations(
struct_type, parser_impl_.GetMemberPipelineDecorations(*struct_type, i),
&member_attrs)) {
return false;
}
if (!EmitPipelineOutput(var_name, var_type, &member_attrs, index_prefix, members[i],
forced_member_type, return_members, return_exprs)) {
return false;
}
// Copy the location as updated by nested expansion of the member.
parser_impl_.SetLocation(decos, GetLocation(member_attrs));
}
return success();
},
[&](Default) {
const bool is_builtin = ast::HasAttribute<ast::BuiltinAttribute>(*decos);
const Type* member_type = is_builtin ? forced_member_type : tip_type;
// Derive the member name directly from the variable name. They can't
// collide.
const auto member_name = namer_.MakeDerivedName(var_name);
// Create the member.
// TODO(dneto): Note: If the parameter has non-location decorations,
// then those decoration AST nodes will be reused between multiple
// elements of a matrix, array, or structure. Normally that's
// disallowed but currently the SPIR-V reader will make duplicates when
// the entire AST is cloned at the top level of the SPIR-V reader flow.
// Consider rewriting this to avoid this node-sharing.
return_members->push_back(
builder_.Member(member_name, member_type->Build(builder_), *decos));
// Create an expression to evaluate the part of the variable indexed by
// the index_prefix.
const ast::Expression* load_source = builder_.Expr(var_name);
// Index into the variable as needed to pick out the flattened member.
auto* current_type = var_type->UnwrapAlias()->UnwrapRef()->UnwrapAlias();
for (auto index : index_prefix) {
Switch(
current_type,
[&](const Matrix* matrix_type) {
load_source =
builder_.IndexAccessor(load_source, builder_.Expr(i32(index)));
current_type = ty_.Vector(matrix_type->type, matrix_type->rows);
},
[&](const Array* array_type) {
load_source =
builder_.IndexAccessor(load_source, builder_.Expr(i32(index)));
current_type = array_type->type->UnwrapAlias();
},
[&](const Struct* struct_type) {
load_source = builder_.MemberAccessor(
load_source,
builder_.Expr(parser_impl_.GetMemberName(*struct_type, index)));
current_type = struct_type->members[index];
});
}
if (is_builtin && (tip_type != forced_member_type)) {
// The member will have the WGSL type, but we need bitcast to
// the variable store type.
load_source = create<ast::BitcastExpression>(forced_member_type->Build(builder_),
load_source);
}
return_exprs->push_back(load_source);
// Increment the location attribute, in case more parameters will
// follow.
IncrementLocation(decos);
return success();
});
}
bool FunctionEmitter::EmitEntryPointAsWrapper() {
Source source;
// The statements in the body.
ast::StatementList stmts;
FunctionDeclaration decl;
decl.source = source;
decl.name = ep_info_->name;
const ast::Type* return_type = nullptr; // Populated below.
// Pipeline inputs become parameters to the wrapper function, and
// their values are saved into the corresponding private variables that
// have already been created.
for (uint32_t var_id : ep_info_->inputs) {
const auto* var = def_use_mgr_->GetDef(var_id);
TINT_ASSERT(Reader, var != nullptr);
TINT_ASSERT(Reader, var->opcode() == SpvOpVariable);
auto* store_type = GetVariableStoreType(*var);
auto* forced_param_type = store_type;
ast::AttributeList param_decos;
if (!parser_impl_.ConvertDecorationsForVariable(var_id, &forced_param_type, &param_decos,
true)) {
// This occurs, and is not an error, for the PointSize builtin.
if (!success()) {
// But exit early if an error was logged.
return false;
}
continue;
}
// We don't have to handle initializers because in Vulkan SPIR-V, Input
// variables must not have them.
const auto var_name = namer_.GetName(var_id);
bool ok = true;
if (HasBuiltinSampleMask(param_decos)) {
// In Vulkan SPIR-V, the sample mask is an array. In WGSL it's a scalar.
// Use the first element only.
auto* sample_mask_array_type = store_type->UnwrapRef()->UnwrapAlias()->As<Array>();
TINT_ASSERT(Reader, sample_mask_array_type);
ok = EmitPipelineInput(var_name, store_type, &param_decos, {0},
sample_mask_array_type->type, forced_param_type, &(decl.params),
&stmts);
} else {
// The normal path.
ok = EmitPipelineInput(var_name, store_type, &param_decos, {}, store_type,
forced_param_type, &(decl.params), &stmts);
}
if (!ok) {
return false;
}
}
// Call the inner function. It has no parameters.
stmts.push_back(create<ast::CallStatement>(
source,
create<ast::CallExpression>(source,
create<ast::IdentifierExpression>(
source, builder_.Symbols().Register(ep_info_->inner_name)),
ast::ExpressionList{})));
// Pipeline outputs are mapped to the return value.
if (ep_info_->outputs.empty()) {
// There is nothing to return.
return_type = ty_.Void()->Build(builder_);
} else {
// Pipeline outputs are converted to a structure that is written
// to just before returning.
const auto return_struct_name = namer_.MakeDerivedName(ep_info_->name + "_out");
const auto return_struct_sym = builder_.Symbols().Register(return_struct_name);
// Define the structure.
std::vector<const ast::StructMember*> return_members;
ast::ExpressionList return_exprs;
const auto& builtin_position_info = parser_impl_.GetBuiltInPositionInfo();
for (uint32_t var_id : ep_info_->outputs) {
if (var_id == builtin_position_info.per_vertex_var_id) {
// The SPIR-V gl_PerVertex variable has already been remapped to
// a gl_Position variable. Substitute the type.
const Type* param_type = ty_.Vector(ty_.F32(), 4);
ast::AttributeList out_decos{
create<ast::BuiltinAttribute>(source, ast::Builtin::kPosition)};
const auto var_name = namer_.GetName(var_id);
return_members.push_back(
builder_.Member(var_name, param_type->Build(builder_), out_decos));
return_exprs.push_back(builder_.Expr(var_name));
} else {
const auto* var = def_use_mgr_->GetDef(var_id);
TINT_ASSERT(Reader, var != nullptr);
TINT_ASSERT(Reader, var->opcode() == SpvOpVariable);
const Type* store_type = GetVariableStoreType(*var);
const Type* forced_member_type = store_type;
ast::AttributeList out_decos;
if (!parser_impl_.ConvertDecorationsForVariable(var_id, &forced_member_type,
&out_decos, true)) {
// This occurs, and is not an error, for the PointSize builtin.
if (!success()) {
// But exit early if an error was logged.
return false;
}
continue;
}
const auto var_name = namer_.GetName(var_id);
bool ok = true;
if (HasBuiltinSampleMask(out_decos)) {
// In Vulkan SPIR-V, the sample mask is an array. In WGSL it's a
// scalar. Use the first element only.
auto* sample_mask_array_type =
store_type->UnwrapRef()->UnwrapAlias()->As<Array>();
TINT_ASSERT(Reader, sample_mask_array_type);
ok = EmitPipelineOutput(var_name, store_type, &out_decos, {0},
sample_mask_array_type->type, forced_member_type,
&return_members, &return_exprs);
} else {
// The normal path.
ok = EmitPipelineOutput(var_name, store_type, &out_decos, {}, store_type,
forced_member_type, &return_members, &return_exprs);
}
if (!ok) {
return false;
}
}
}
if (return_members.empty()) {
// This can occur if only the PointSize member is accessed, because we
// never emit it.
return_type = ty_.Void()->Build(builder_);
} else {
// Create and register the result type.
auto* str = create<ast::Struct>(Source{}, return_struct_sym, return_members,
ast::AttributeList{});
parser_impl_.AddTypeDecl(return_struct_sym, str);
return_type = builder_.ty.Of(str);
// Add the return-value statement.
stmts.push_back(create<ast::ReturnStatement>(
source, builder_.Construct(source, return_type, std::move(return_exprs))));
}
}
auto* body = create<ast::BlockStatement>(source, stmts);
ast::AttributeList fn_attrs;
fn_attrs.emplace_back(create<ast::StageAttribute>(source, ep_info_->stage));
if (ep_info_->stage == ast::PipelineStage::kCompute) {
auto& size = ep_info_->workgroup_size;
if (size.x != 0 && size.y != 0 && size.z != 0) {
const ast::Expression* x = builder_.Expr(i32(size.x));
const ast::Expression* y = size.y ? builder_.Expr(i32(size.y)) : nullptr;
const ast::Expression* z = size.z ? builder_.Expr(i32(size.z)) : nullptr;
fn_attrs.emplace_back(create<ast::WorkgroupAttribute>(Source{}, x, y, z));
}
}
builder_.AST().AddFunction(create<ast::Function>(
source, builder_.Symbols().Register(ep_info_->name), std::move(decl.params), return_type,
body, std::move(fn_attrs), ast::AttributeList{}));
return true;
}
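// Overall wrapper shape (an illustrative sketch; the names below are
// hypothetical): for a vertex shader with one user input and a position
// output, the generated WGSL is roughly
//
//   struct main_out {
//     @builtin(position) gl_Position : vec4<f32>,
//   }
//   @stage(vertex)
//   fn main(@location(0) a_param : f32) -> main_out {
//     a = a_param;  // copy the input into its module-scope private var
//     main_1();     // call the inner implementation function
//     return main_out(gl_Position);
//   }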
bool FunctionEmitter::ParseFunctionDeclaration(FunctionDeclaration* decl) {
if (failed()) {
return false;
}
const std::string name = namer_.Name(function_.result_id());
// Surprisingly, the "type id" on an OpFunction is the result type of the
// function, not the type of the function. This is the one exceptional case
// in SPIR-V where the type ID is not the type of the result ID.
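// For example (illustrative IDs):
//   %9 = OpTypeFunction %float %uint
//   %50 = OpFunction %float None %9
// Here function_.type_id() is %float, the return type, while the function's
// own type is %9.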
auto* ret_ty = parser_impl_.ConvertType(function_.type_id());
if (failed()) {
return false;
}
if (ret_ty == nullptr) {
return Fail() << "internal error: unregistered return type for function with ID "
<< function_.result_id();
}
ast::VariableList ast_params;
function_.ForEachParam([this, &ast_params](const spvtools::opt::Instruction* param) {
auto* type = parser_impl_.ConvertType(param->type_id());
if (type != nullptr) {
auto* ast_param =
parser_impl_.MakeVariable(param->result_id(), ast::StorageClass::kNone, type, true,
false, nullptr, ast::AttributeList{});
// Parameters are treated as const declarations.
ast_params.emplace_back(ast_param);
// The value is accessible by name.
identifier_types_.emplace(param->result_id(), type);
} else {
// We've already logged an error and emitted a diagnostic. Do nothing
// here.
}
});
if (failed()) {
return false;
}
decl->name = name;
decl->params = std::move(ast_params);
decl->return_type = ret_ty;
decl->attributes.clear();
return success();
}
const Type* FunctionEmitter::GetVariableStoreType(const spvtools::opt::Instruction& var_decl_inst) {
const auto type_id = var_decl_inst.type_id();
// Normally we use the SPIRV-Tools optimizer to manage types.
// But when two struct types have the same member types and decorations,
// but differ only in member names, the two struct types will be
// represented by a single common internal struct type.
// So avoid the optimizer's representation and instead follow the
// SPIR-V instructions themselves.
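// For example (illustrative IDs): given
//   %ptr = OpTypePointer Private %my_struct
//   %var = OpVariable %ptr Private
// the store type is %my_struct, i.e. in-operand 1 of the OpTypePointer
// instruction, read directly off the instruction below.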
const auto* ptr_ty = def_use_mgr_->GetDef(type_id);
const auto store_ty_id = ptr_ty->GetSingleWordInOperand(1);
const auto* result = parser_impl_.ConvertType(store_ty_id);
return result;
}
bool FunctionEmitter::EmitBody() {
RegisterBasicBlocks();
if (!TerminatorsAreValid()) {
return false;
}
if (!RegisterMerges()) {
return false;
}
ComputeBlockOrderAndPositions();
if (!VerifyHeaderContinueMergeOrder()) {
return false;
}
if (!LabelControlFlowConstructs()) {
return false;
}
if (!FindSwitchCaseHeaders()) {
return false;
}
if (!ClassifyCFGEdges()) {
return false;
}
if (!FindIfSelectionInternalHeaders()) {
return false;
}
if (!RegisterSpecialBuiltInVariables()) {
return false;
}
if (!RegisterLocallyDefinedValues()) {
return false;
}
FindValuesNeedingNamedOrHoistedDefinition();
if (!EmitFunctionVariables()) {
return false;
}
if (!EmitFunctionBodyStatements()) {
return false;
}
return success();
}
void FunctionEmitter::RegisterBasicBlocks() {
for (auto& block : function_) {
block_info_[block.id()] = std::make_unique<BlockInfo>(block);
}
}
bool FunctionEmitter::TerminatorsAreValid() {
if (failed()) {
return false;
}
const auto entry_id = function_.begin()->id();
for (const auto& block : function_) {
if (!block.terminator()) {
return Fail() << "Block " << block.id() << " has no terminator";
}
}
for (const auto& block : function_) {
block.WhileEachSuccessorLabel([this, &block, entry_id](const uint32_t succ_id) -> bool {
if (succ_id == entry_id) {
return Fail() << "Block " << block.id() << " branches to function entry block "
<< entry_id;
}
if (!GetBlockInfo(succ_id)) {
return Fail() << "Block " << block.id() << " in function "
<< function_.DefInst().result_id() << " branches to " << succ_id
<< " which is not a block in the function";
}
return true;
});
}
return success();
}
bool FunctionEmitter::RegisterMerges() {
if (failed()) {
return false;
}
const auto entry_id = function_.begin()->id();
for (const auto& block : function_) {
const auto block_id = block.id();
auto* block_info = GetBlockInfo(block_id);
if (!block_info) {
return Fail() << "internal error: block " << block_id
<< " missing; blocks should already "
"have been registered";
}
if (const auto* inst = block.GetMergeInst()) {
auto terminator_opcode = block.terminator()->opcode();
switch (inst->opcode()) {
case SpvOpSelectionMerge:
if ((terminator_opcode != SpvOpBranchConditional) &&
(terminator_opcode != SpvOpSwitch)) {
return Fail() << "Selection header " << block_id
<< " does not end in an OpBranchConditional or "
"OpSwitch instruction";
}
break;
case SpvOpLoopMerge:
if ((terminator_opcode != SpvOpBranchConditional) &&
(terminator_opcode != SpvOpBranch)) {
return Fail() << "Loop header " << block_id
<< " does not end in an OpBranch or "
"OpBranchConditional instruction";
}
break;
default:
break;
}
const uint32_t header = block.id();
auto* header_info = block_info;
const uint32_t merge = inst->GetSingleWordInOperand(0);
auto* merge_info = GetBlockInfo(merge);
if (!merge_info) {
return Fail() << "Structured header block " << header
<< " declares invalid merge block " << merge;
}
if (merge == header) {
return Fail() << "Structured header block " << header
<< " cannot be its own merge block";
}
if (merge_info->header_for_merge) {
return Fail() << "Block " << merge
<< " declared as merge block for more than one header: "
<< merge_info->header_for_merge << ", " << header;
}
merge_info->header_for_merge = header;
header_info->merge_for_header = merge;
if (inst->opcode() == SpvOpLoopMerge) {
if (header == entry_id) {
return Fail() << "Function entry block " << entry_id
<< " cannot be a loop header";
}
const uint32_t ct = inst->GetSingleWordInOperand(1);
auto* ct_info = GetBlockInfo(ct);
if (!ct_info) {
return Fail() << "Structured header " << header
<< " declares invalid continue target " << ct;
}
if (ct == merge) {
return Fail() << "Invalid structured header block " << header
<< ": declares block " << ct
<< " as both its merge block and continue target";
}
if (ct_info->header_for_continue) {
return Fail() << "Block " << ct
<< " declared as continue target for more than one header: "
<< ct_info->header_for_continue << ", " << header;
}
ct_info->header_for_continue = header;
header_info->continue_for_header = ct;
}
}
// Check single-block loop cases.
bool is_single_block_loop = false;
block_info->basic_block->ForEachSuccessorLabel(
[&is_single_block_loop, block_id](const uint32_t succ) {
if (block_id == succ)
is_single_block_loop = true;
});
const auto ct = block_info->continue_for_header;
block_info->is_continue_entire_loop = ct == block_id;
if (is_single_block_loop && !block_info->is_continue_entire_loop) {
return Fail() << "Block " << block_id
<< " branches to itself but is not its own continue target";
}
// It's valid for the header of a multi-block loop to declare itself as
// its own continue target.
}
return success();
}
void FunctionEmitter::ComputeBlockOrderAndPositions() {
block_order_ = StructuredTraverser(function_).ReverseStructuredPostOrder();
for (uint32_t i = 0; i < block_order_.size(); ++i) {
GetBlockInfo(block_order_[i])->pos = i;
}
// The invalid block position is not the position of any block that is in the
// order.
assert(block_order_.size() <= kInvalidBlockPos);
}
bool FunctionEmitter::VerifyHeaderContinueMergeOrder() {
// Verify interval rules for a structured header block:
//
// If the CFG satisfies structured control flow rules, then:
// If header H is reachable, then the following "interval rules" hold,
// where M(H) is H's merge block, and CT(H) is H's continue target:
//
// Pos(H) < Pos(M(H))
//
// If CT(H) exists, then:
// Pos(H) <= Pos(CT(H))
// Pos(CT(H)) < Pos(M(H))
//
for (auto block_id : block_order_) {
const auto* block_info = GetBlockInfo(block_id);
const auto merge = block_info->merge_for_header;
if (merge == 0) {
continue;
}
// This is a header.
const auto header = block_id;
const auto* header_info = block_info;
const auto header_pos = header_info->pos;
const auto merge_pos = GetBlockInfo(merge)->pos;
// Pos(H) < Pos(M(H))
// Note: When recording merges we made sure H != M(H)
if (merge_pos <= header_pos) {
return Fail() << "Header " << header << " does not strictly dominate its merge block "
<< merge;
// TODO(dneto): Report a path from the entry block to the merge block
// without going through the header block.
}
const auto ct = block_info->continue_for_header;
if (ct == 0) {
continue;
}
// Furthermore, this is a loop header.
const auto* ct_info = GetBlockInfo(ct);
const auto ct_pos = ct_info->pos;
// Pos(H) <= Pos(CT(H))
if (ct_pos < header_pos) {
return Fail() << "Loop header " << header << " does not dominate its continue target " << ct;
}
// Pos(CT(H)) < Pos(M(H))
// Note: When recording merges we made sure CT(H) != M(H)
if (merge_pos <= ct_pos) {
return Fail() << "Merge block " << merge << " for loop headed at block " << header
<< " appears at or before the loop's continue "
"construct headed by "
"block "
<< ct;
}
}
return success();
}
bool FunctionEmitter::LabelControlFlowConstructs() {
// Label each block in the block order with its nearest enclosing structured
// control flow construct. Populates the |construct| member of BlockInfo.
// Keep a stack of enclosing structured control flow constructs. Start
// with the synthetic construct representing the entire function.
//
// Scan from left to right in the block order, and check conditions
// on each block in the following order:
//
// a. When you reach a merge block, the top of the stack should
// be the associated header. Pop it off.
// b. When you reach a header, push it on the stack.
// c. When you reach a continue target, push it on the stack.
// (A block can be both a header and a continue target.)
// d. When you reach a block with an edge branching backward (in the
// structured order) to block T:
// T should be a loop header, and the top of the stack should be a
// continue target associated with T.
// This is the end of the continue construct. Pop the continue
// target off the stack.
//
// Note: A loop header can declare itself as its own continue target.
//
// Note: For a single-block loop, that block is a header, its own
// continue target, and its own backedge block.
//
// Note: We pop the merge off first because a merge block that marks
// the end of one construct can be a single-block loop. So that block
// is a merge, a header, a continue target, and a backedge block.
// But we want to finish processing of the merge before dealing with
// the loop.
//
// In the same scan, mark each basic block with the nearest enclosing
// header: the most recent header for which we haven't reached its merge
// block. Also mark the most recent continue target for which we
// haven't reached the backedge block.
TINT_ASSERT(Reader, block_order_.size() > 0);
constructs_.clear();
const auto entry_id = block_order_[0];
// The stack of enclosing constructs.
std::vector<Construct*> enclosing;
// Creates a control flow construct and pushes it onto the stack.
// Its parent is the top of the stack, or nullptr if the stack is empty.
// Returns the newly created construct.
auto push_construct = [this, &enclosing](size_t depth, Construct::Kind k, uint32_t begin_id,
uint32_t end_id) -> Construct* {
const auto begin_pos = GetBlockInfo(begin_id)->pos;
const auto end_pos =
end_id == 0 ? uint32_t(block_order_.size()) : GetBlockInfo(end_id)->pos;
const auto* parent = enclosing.empty() ? nullptr : enclosing.back();
auto scope_end_pos = end_pos;
// A loop construct is added right after its associated continue construct.
// In that case, adjust the parent up.
if (k == Construct::kLoop) {
TINT_ASSERT(Reader, parent);
TINT_ASSERT(Reader, parent->kind == Construct::kContinue);
scope_end_pos = parent->end_pos;
parent = parent->parent;
}
constructs_.push_back(std::make_unique<Construct>(parent, static_cast<int>(depth), k,
begin_id, end_id, begin_pos, end_pos,
scope_end_pos));
Construct* result = constructs_.back().get();
enclosing.push_back(result);
return result;
};
// Make a synthetic kFunction construct to enclose all blocks in the function.
push_construct(0, Construct::kFunction, entry_id, 0);
// The entry block can be a selection construct, so be sure to process
// it anyway.
for (uint32_t i = 0; i < block_order_.size(); ++i) {
const auto block_id = block_order_[i];
TINT_ASSERT(Reader, block_id > 0);
auto* block_info = GetBlockInfo(block_id);
TINT_ASSERT(Reader, block_info);
if (enclosing.empty()) {
return Fail() << "internal error: too many merge blocks before block " << block_id;
}
const Construct* top = enclosing.back();
while (block_id == top->end_id) {
// We've reached a predeclared end of the construct. Pop it off the
// stack.
enclosing.pop_back();
if (enclosing.empty()) {
return Fail() << "internal error: too many merge blocks before block " << block_id;
}
top = enclosing.back();
}
const auto merge = block_info->merge_for_header;
if (merge != 0) {
// The current block is a header.
const auto header = block_id;
const auto* header_info = block_info;
const auto depth = 1 + top->depth;
const auto ct = header_info->continue_for_header;
if (ct != 0) {
// The current block is a loop header.
// We should see the continue construct after the loop construct, so
// push the loop construct last.
// From the interval rule, the continue construct consists of blocks
// in the block order, starting at the continue target, until just
// before the merge block.
top = push_construct(depth, Construct::kContinue, ct, merge);
// A loop header that is its own continue target will have an
// empty loop construct. Only create a loop construct when
// the continue target is *not* the same as the loop header.
if (header != ct) {
// From the interval rule, the loop construct consists of blocks
// in the block order, starting at the header, until just
// before the continue target.
top = push_construct(depth, Construct::kLoop, header, ct);
// If the loop header branches to two different blocks inside the loop
// construct, then the loop body should be modeled as an if-selection
// construct
std::vector<uint32_t> targets;
header_info->basic_block->ForEachSuccessorLabel(
[&targets](const uint32_t target) { targets.push_back(target); });
if ((targets.size() == 2u) && targets[0] != targets[1]) {
const auto target0_pos = GetBlockInfo(targets[0])->pos;
const auto target1_pos = GetBlockInfo(targets[1])->pos;
if (top->ContainsPos(target0_pos) && top->ContainsPos(target1_pos)) {
// Insert a synthetic if-selection
top = push_construct(depth + 1, Construct::kIfSelection, header, ct);
}
}
}
} else {
// From the interval rule, the selection construct consists of blocks
// in the block order, starting at the header, until just before the
// merge block.
const auto branch_opcode = header_info->basic_block->terminator()->opcode();
const auto kind = (branch_opcode == SpvOpBranchConditional)
? Construct::kIfSelection
: Construct::kSwitchSelection;
top = push_construct(depth, kind, header, merge);
}
}
TINT_ASSERT(Reader, top);
block_info->construct = top;
}
// At the end of the block list, we should only have the kFunction construct
// left.
if (enclosing.size() != 1) {
return Fail() << "internal error: unbalanced structured constructs when "
"labeling structured constructs: ended with "
<< enclosing.size() - 1 << " unterminated constructs";
}
const auto* top = enclosing[0];
if (top->kind != Construct::kFunction || top->depth != 0) {
return Fail() << "internal error: outermost construct is not a function?!";
}
return success();
}
bool FunctionEmitter::FindSwitchCaseHeaders() {
if (failed()) {
return false;
}
for (auto& construct : constructs_) {
if (construct->kind != Construct::kSwitchSelection) {
continue;
}
const auto* branch = GetBlockInfo(construct->begin_id)->basic_block->terminator();
// Mark the default block
const auto default_id = branch->GetSingleWordInOperand(1);
auto* default_block = GetBlockInfo(default_id);
// A default target can't be a backedge.
if (construct->begin_pos >= default_block->pos) {
// An OpSwitch must dominate its cases. Also, it can't be a self-loop
// as that would be a backedge, and backedges can only target a loop,
// and loops use an OpLoopMerge instruction, which can't precede an
// OpSwitch.
return Fail() << "Switch branch from block " << construct->begin_id
<< " to default target block " << default_id << " can't be a back-edge";
}
// A default target can be the merge block, but can't go past it.
if (construct->end_pos < default_block->pos) {
return Fail() << "Switch branch from block " << construct->begin_id
<< " to default block " << default_id
<< " escapes the selection construct";
}
if (default_block->default_head_for) {
// An OpSwitch must dominate its cases, including the default target.
return Fail() << "Block " << default_id
<< " is declared as the default target for two OpSwitch "
"instructions, at blocks "
<< default_block->default_head_for->begin_id << " and "
<< construct->begin_id;
}
if ((default_block->header_for_merge != 0) &&
(default_block->header_for_merge != construct->begin_id)) {
// The switch instruction for this default block is an alternate path to
// the merge block, and hence the merge block is not dominated by its own
// (different) header.
return Fail() << "Block " << default_block->id
<< " is the default block for switch-selection header "
<< construct->begin_id << " and also the merge block for "
<< default_block->header_for_merge << " (violates dominance rule)";
}
default_block->default_head_for = construct.get();
default_block->default_is_merge = default_block->pos == construct->end_pos;
// Map a case target to the list of values selecting that case.
std::unordered_map<uint32_t, std::vector<uint64_t>> block_to_values;
std::vector<uint32_t> case_targets;
std::unordered_set<uint64_t> case_values;
// Process case targets.
for (uint32_t iarg = 2; iarg + 1 < branch->NumInOperands(); iarg += 2) {
const auto value = branch->GetInOperand(iarg).AsLiteralUint64();
const auto case_target_id = branch->GetSingleWordInOperand(iarg + 1);
if (case_values.count(value)) {
return Fail() << "Duplicate case value " << value << " in OpSwitch in block "
<< construct->begin_id;
}
case_values.insert(value);
if (block_to_values.count(case_target_id) == 0) {
case_targets.push_back(case_target_id);
}
block_to_values[case_target_id].push_back(value);
}
for (uint32_t case_target_id : case_targets) {
auto* case_block = GetBlockInfo(case_target_id);
case_block->case_values =
std::make_unique<std::vector<uint64_t>>(std::move(block_to_values[case_target_id]));
// A case target can't be a back-edge.
if (construct->begin_pos >= case_block->pos) {
// An OpSwitch must dominate its cases. Also, it can't be a self-loop
// as that would be a backedge, and backedges can only target a loop,
                // and loops use an OpLoopMerge instruction, which can't precede an
// OpSwitch.
return Fail() << "Switch branch from block " << construct->begin_id
<< " to case target block " << case_target_id
<< " can't be a back-edge";
}
// A case target can be the merge block, but can't go past it.
if (construct->end_pos < case_block->pos) {
return Fail() << "Switch branch from block " << construct->begin_id
<< " to case target block " << case_target_id
<< " escapes the selection construct";
}
if (case_block->header_for_merge != 0 &&
case_block->header_for_merge != construct->begin_id) {
// The switch instruction for this case block is an alternate path to
// the merge block, and hence the merge block is not dominated by its
// own (different) header.
return Fail() << "Block " << case_block->id
<< " is a case block for switch-selection header "
<< construct->begin_id << " and also the merge block for "
<< case_block->header_for_merge << " (violates dominance rule)";
}
// Mark the target as a case target.
if (case_block->case_head_for) {
// An OpSwitch must dominate its cases.
return Fail() << "Block " << case_target_id
<< " is declared as the switch case target for two OpSwitch "
"instructions, at blocks "
<< case_block->case_head_for->begin_id << " and "
<< construct->begin_id;
}
case_block->case_head_for = construct.get();
}
}
return success();
}
BlockInfo* FunctionEmitter::HeaderIfBreakable(const Construct* c) {
if (c == nullptr) {
return nullptr;
}
switch (c->kind) {
case Construct::kLoop:
case Construct::kSwitchSelection:
return GetBlockInfo(c->begin_id);
case Construct::kContinue: {
const auto* continue_target = GetBlockInfo(c->begin_id);
return GetBlockInfo(continue_target->header_for_continue);
}
default:
break;
}
return nullptr;
}
const Construct* FunctionEmitter::SiblingLoopConstruct(const Construct* c) const {
if (c == nullptr || c->kind != Construct::kContinue) {
return nullptr;
}
const uint32_t continue_target_id = c->begin_id;
const auto* continue_target = GetBlockInfo(continue_target_id);
const uint32_t header_id = continue_target->header_for_continue;
if (continue_target_id == header_id) {
// The continue target is the whole loop.
return nullptr;
}
const auto* candidate = GetBlockInfo(header_id)->construct;
// Walk up the construct tree until we hit the loop. In future
// we might handle the corner case where the same block is both a
// loop header and a selection header. For example, where the
// loop header block has a conditional branch going to distinct
// targets inside the loop body.
while (candidate && candidate->kind != Construct::kLoop) {
candidate = candidate->parent;
}
return candidate;
}
bool FunctionEmitter::ClassifyCFGEdges() {
if (failed()) {
return false;
}
// Checks validity of CFG edges leaving each basic block. This implicitly
// checks dominance rules for headers and continue constructs.
//
// For each branch encountered, classify each edge (S,T) as:
// - a back-edge
// - a structured exit (specific ways of branching to enclosing construct)
// - a normal (forward) edge, either natural control flow or a case
// fallthrough
//
// If more than one block is targeted by a normal edge, then S must be a
// structured header.
//
// Term: NEC(B) is the nearest enclosing construct for B.
//
// If edge (S,T) is a normal edge, and NEC(S) != NEC(T), then
// T is the header block of its NEC(T), and
// NEC(S) is the parent of NEC(T).
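    //
    // For example (an illustrative sketch): in a loop headed at %10 with
    // merge %40 and continue target %30, the edge %30->%10 is kBack, an
    // edge %20->%40 is kLoopBreak, an edge %20->%30 is kLoopContinue, and
    // an edge %20->%25 that stays inside the loop body is kForward.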
for (const auto src : block_order_) {
TINT_ASSERT(Reader, src > 0);
auto* src_info = GetBlockInfo(src);
TINT_ASSERT(Reader, src_info);
const auto src_pos = src_info->pos;
const auto& src_construct = *(src_info->construct);
// Compute the ordered list of unique successors.
std::vector<uint32_t> successors;
{
std::unordered_set<uint32_t> visited;
src_info->basic_block->ForEachSuccessorLabel(
[&successors, &visited](const uint32_t succ) {
if (visited.count(succ) == 0) {
successors.push_back(succ);
visited.insert(succ);
}
});
}
// There should only be one backedge per backedge block.
uint32_t num_backedges = 0;
// Track destinations for normal forward edges, either kForward
// or kCaseFallThrough. These count toward the need
// to have a merge instruction. We also track kIfBreak edges
// because when used with normal forward edges, we'll need
// to generate a flow guard variable.
std::vector<uint32_t> normal_forward_edges;
std::vector<uint32_t> if_break_edges;
if (successors.empty() && src_construct.enclosing_continue) {
// Kill and return are not allowed in a continue construct.
return Fail() << "Invalid function exit at block " << src
<< " from continue construct starting at "
<< src_construct.enclosing_continue->begin_id;
}
for (const auto dest : successors) {
const auto* dest_info = GetBlockInfo(dest);
// We've already checked terminators are valid.
TINT_ASSERT(Reader, dest_info);
const auto dest_pos = dest_info->pos;
// Insert the edge kind entry and keep a handle to update
// its classification.
EdgeKind& edge_kind = src_info->succ_edge[dest];
if (src_pos >= dest_pos) {
// This is a backedge.
edge_kind = EdgeKind::kBack;
num_backedges++;
const auto* continue_construct = src_construct.enclosing_continue;
if (!continue_construct) {
return Fail() << "Invalid backedge (" << src << "->" << dest << "): " << src
<< " is not in a continue construct";
}
if (src_pos != continue_construct->end_pos - 1) {
return Fail() << "Invalid exit (" << src << "->" << dest
<< ") from continue construct: " << src
<< " is not the last block in the continue construct "
"starting at "
<< src_construct.begin_id << " (violates post-dominance rule)";
}
const auto* ct_info = GetBlockInfo(continue_construct->begin_id);
TINT_ASSERT(Reader, ct_info);
if (ct_info->header_for_continue != dest) {
return Fail() << "Invalid backedge (" << src << "->" << dest
<< "): does not branch to the corresponding loop header, "
"expected "
<< ct_info->header_for_continue;
}
} else {
// This is a forward edge.
// For now, classify it that way, but we might update it.
edge_kind = EdgeKind::kForward;
// Exit from a continue construct can only be from the last block.
const auto* continue_construct = src_construct.enclosing_continue;
if (continue_construct != nullptr) {
if (continue_construct->ContainsPos(src_pos) &&
!continue_construct->ContainsPos(dest_pos) &&
(src_pos != continue_construct->end_pos - 1)) {
return Fail()
<< "Invalid exit (" << src << "->" << dest
<< ") from continue construct: " << src
<< " is not the last block in the continue construct "
"starting at "
<< continue_construct->begin_id << " (violates post-dominance rule)";
}
}
// Check valid structured exit cases.
if (edge_kind == EdgeKind::kForward) {
// Check for a 'break' from a loop or from a switch.
const auto* breakable_header =
HeaderIfBreakable(src_construct.enclosing_loop_or_continue_or_switch);
if (breakable_header != nullptr) {
if (dest == breakable_header->merge_for_header) {
// It's a break.
edge_kind =
(breakable_header->construct->kind == Construct::kSwitchSelection)
? EdgeKind::kSwitchBreak
: EdgeKind::kLoopBreak;
}
}
}
if (edge_kind == EdgeKind::kForward) {
// Check for a 'continue' from within a loop.
const auto* loop_header = HeaderIfBreakable(src_construct.enclosing_loop);
if (loop_header != nullptr) {
if (dest == loop_header->continue_for_header) {
// It's a continue.
edge_kind = EdgeKind::kLoopContinue;
}
}
}
if (edge_kind == EdgeKind::kForward) {
const auto& header_info = *GetBlockInfo(src_construct.begin_id);
if (dest == header_info.merge_for_header) {
// Branch to construct's merge block. The loop break and
// switch break cases have already been covered.
edge_kind = EdgeKind::kIfBreak;
}
}
// A forward edge into a case construct that comes from something
// other than the OpSwitch is actually a fallthrough.
if (edge_kind == EdgeKind::kForward) {
const auto* switch_construct =
(dest_info->case_head_for ? dest_info->case_head_for
: dest_info->default_head_for);
if (switch_construct != nullptr) {
if (src != switch_construct->begin_id) {
edge_kind = EdgeKind::kCaseFallThrough;
}
}
}
// The edge-kind has been finalized.
if ((edge_kind == EdgeKind::kForward) ||
(edge_kind == EdgeKind::kCaseFallThrough)) {
normal_forward_edges.push_back(dest);
}
if (edge_kind == EdgeKind::kIfBreak) {
if_break_edges.push_back(dest);
}
if ((edge_kind == EdgeKind::kForward) ||
(edge_kind == EdgeKind::kCaseFallThrough)) {
// Check for an invalid forward exit out of this construct.
if (dest_info->pos > src_construct.end_pos) {
// In most cases we're bypassing the merge block for the source
// construct.
auto end_block = src_construct.end_id;
const char* end_block_desc = "merge block";
if (src_construct.kind == Construct::kLoop) {
// For a loop construct, we have two valid places to go: the
// continue target or the merge for the loop header, which is
// further down.
const auto loop_merge =
GetBlockInfo(src_construct.begin_id)->merge_for_header;
if (dest_info->pos >= GetBlockInfo(loop_merge)->pos) {
// We're bypassing the loop's merge block.
end_block = loop_merge;
} else {
// We're bypassing the loop's continue target, and going into
// the middle of the continue construct.
end_block_desc = "continue target";
}
}
return Fail() << "Branch from block " << src << " to block " << dest
<< " is an invalid exit from construct starting at block "
<< src_construct.begin_id << "; branch bypasses "
<< end_block_desc << " " << end_block;
}
// Check dominance.
// Look for edges that violate the dominance condition: a branch
// from X to Y where:
// If Y is in a nearest enclosing continue construct headed by
// CT:
// Y is not CT, and
                //     In the structured order, X appears before CT or
// after CT's backedge block.
// Otherwise, if Y is in a nearest enclosing construct
// headed by H:
// Y is not H, and
// In the structured order, X appears before H or after H's
// merge block.
const auto& dest_construct = *(dest_info->construct);
if (dest != dest_construct.begin_id && !dest_construct.ContainsPos(src_pos)) {
return Fail()
<< "Branch from " << src << " to " << dest << " bypasses "
<< (dest_construct.kind == Construct::kContinue ? "continue target "
: "header ")
<< dest_construct.begin_id << " (dominance rule violated)";
}
}
} // end forward edge
} // end successor
if (num_backedges > 1) {
return Fail() << "Block " << src << " has too many backedges: " << num_backedges;
}
if ((normal_forward_edges.size() > 1) && (src_info->merge_for_header == 0)) {
return Fail() << "Control flow diverges at block " << src << " (to "
<< normal_forward_edges[0] << ", " << normal_forward_edges[1]
<< ") but it is not a structured header (it has no merge "
"instruction)";
}
if ((normal_forward_edges.size() + if_break_edges.size() > 1) &&
(src_info->merge_for_header == 0)) {
// There is a branch to the merge of an if-selection combined
                // with another normal forward branch. Control within the
// if-selection needs to be gated by a flow predicate.
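                // For example (an illustrative WGSL shape, with a made-up
                // header ID %101 and guard name):
                //     var guard101 : bool = true;
                //     if (cond) { ...; guard101 = false; }
                //     if (guard101) { ... }  // statements after the early exit
                // The guard is declared in EmitIfStart and set to false by
                // MakeBranchDetailed when taking a kIfBreak edge.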
for (auto if_break_dest : if_break_edges) {
auto* head_info = GetBlockInfo(GetBlockInfo(if_break_dest)->header_for_merge);
// Generate a guard name, but only once.
if (head_info->flow_guard_name.empty()) {
const std::string guard = "guard" + std::to_string(head_info->id);
head_info->flow_guard_name = namer_.MakeDerivedName(guard);
}
}
}
}
return success();
}
bool FunctionEmitter::FindIfSelectionInternalHeaders() {
if (failed()) {
return false;
}
for (auto& construct : constructs_) {
if (construct->kind != Construct::kIfSelection) {
continue;
}
auto* if_header_info = GetBlockInfo(construct->begin_id);
const auto* branch = if_header_info->basic_block->terminator();
const auto true_head = branch->GetSingleWordInOperand(1);
const auto false_head = branch->GetSingleWordInOperand(2);
auto* true_head_info = GetBlockInfo(true_head);
auto* false_head_info = GetBlockInfo(false_head);
const auto true_head_pos = true_head_info->pos;
const auto false_head_pos = false_head_info->pos;
const bool contains_true = construct->ContainsPos(true_head_pos);
const bool contains_false = construct->ContainsPos(false_head_pos);
// The cases for each edge are:
        //  - kBack: invalid, because a back-edge is not a valid exit from the selection
        //  - kSwitchBreak: record this for later special processing
        //  - kLoopBreak: record this for later special processing
        //  - kLoopContinue: record this for later special processing
        //  - kIfBreak: normal case, may require a guard variable
        //  - kCaseFallThrough: invalid exit from the selection
        //  - kForward: normal case
if_header_info->true_kind = if_header_info->succ_edge[true_head];
if_header_info->false_kind = if_header_info->succ_edge[false_head];
if (contains_true) {
if_header_info->true_head = true_head;
}
if (contains_false) {
if_header_info->false_head = false_head;
}
if (contains_true && (true_head_info->header_for_merge != 0) &&
(true_head_info->header_for_merge != construct->begin_id)) {
// The OpBranchConditional instruction for the true head block is an
// alternate path to the merge block of a construct nested inside the
// selection, and hence the merge block is not dominated by its own
// (different) header.
return Fail() << "Block " << true_head << " is the true branch for if-selection header "
<< construct->begin_id << " and also the merge block for header block "
<< true_head_info->header_for_merge << " (violates dominance rule)";
}
if (contains_false && (false_head_info->header_for_merge != 0) &&
(false_head_info->header_for_merge != construct->begin_id)) {
// The OpBranchConditional instruction for the false head block is an
// alternate path to the merge block of a construct nested inside the
// selection, and hence the merge block is not dominated by its own
// (different) header.
return Fail() << "Block " << false_head
<< " is the false branch for if-selection header " << construct->begin_id
<< " and also the merge block for header block "
<< false_head_info->header_for_merge << " (violates dominance rule)";
}
if (contains_true && contains_false && (true_head_pos != false_head_pos)) {
// This construct has both a "then" clause and an "else" clause.
//
// We have this structure:
//
// Option 1:
//
// * condbranch
// * true-head (start of then-clause)
// ...
// * end-then-clause
// * false-head (start of else-clause)
// ...
// * end-false-clause
// * premerge-head
// ...
// * selection merge
//
// Option 2:
//
// * condbranch
// * true-head (start of then-clause)
// ...
// * end-then-clause
// * false-head (start of else-clause) and also premerge-head
// ...
// * end-false-clause
// * selection merge
//
// Option 3:
//
// * condbranch
// * false-head (start of else-clause)
// ...
// * end-else-clause
// * true-head (start of then-clause) and also premerge-head
// ...
// * end-then-clause
// * selection merge
//
// The premerge-head exists if there is a kForward branch from the end
// of the first clause to a block within the surrounding selection.
// The first clause might be a then-clause or an else-clause.
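        //
        // For example (an illustrative sketch with made-up IDs): if the
        // first clause ends at %25 with a kForward branch to %40, and %40
        // lies inside this selection, then %40 is recorded below as the
        // premerge head. A branch to the declared merge block would instead
        // have been classified as kIfBreak, not kForward.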
const auto second_head = std::max(true_head_pos, false_head_pos);
const auto end_first_clause_pos = second_head - 1;
TINT_ASSERT(Reader, end_first_clause_pos < block_order_.size());
const auto end_first_clause = block_order_[end_first_clause_pos];
uint32_t premerge_id = 0;
uint32_t if_break_id = 0;
for (auto& then_succ_iter : GetBlockInfo(end_first_clause)->succ_edge) {
const uint32_t dest_id = then_succ_iter.first;
const auto edge_kind = then_succ_iter.second;
switch (edge_kind) {
case EdgeKind::kIfBreak:
if_break_id = dest_id;
break;
case EdgeKind::kForward: {
if (construct->ContainsPos(GetBlockInfo(dest_id)->pos)) {
// It's a premerge.
if (premerge_id != 0) {
// TODO(dneto): I think this is impossible to trigger at this
// point in the flow. It would require a merge instruction to
// get past the check of "at-most-one-forward-edge".
return Fail()
<< "invalid structure: then-clause headed by block "
<< true_head << " ending at block " << end_first_clause
<< " has two forward edges to within selection"
<< " going to " << premerge_id << " and " << dest_id;
}
premerge_id = dest_id;
auto* dest_block_info = GetBlockInfo(dest_id);
if_header_info->premerge_head = dest_id;
if (dest_block_info->header_for_merge != 0) {
// Premerge has two edges coming into it, from the then-clause
// and the else-clause. It's also, by construction, not the
// merge block of the if-selection. So it must not be a merge
// block itself. The OpBranchConditional instruction for the
// false head block is an alternate path to the merge block, and
// hence the merge block is not dominated by its own (different)
// header.
return Fail()
<< "Block " << premerge_id << " is the merge block for "
<< dest_block_info->header_for_merge
<< " but has alternate paths reaching it, starting from"
<< " blocks " << true_head << " and " << false_head
<< " which are the true and false branches for the"
<< " if-selection header block " << construct->begin_id
<< " (violates dominance rule)";
}
}
break;
}
default:
break;
}
}
if (if_break_id != 0 && premerge_id != 0) {
return Fail() << "Block " << end_first_clause << " in if-selection headed at block "
<< construct->begin_id << " branches to both the merge block "
<< if_break_id << " and also to block " << premerge_id
<< " later in the selection";
}
}
}
return success();
}
bool FunctionEmitter::EmitFunctionVariables() {
if (failed()) {
return false;
}
for (auto& inst : *function_.entry()) {
if (inst.opcode() != SpvOpVariable) {
continue;
}
auto* var_store_type = GetVariableStoreType(inst);
if (failed()) {
return false;
}
const ast::Expression* constructor = nullptr;
if (inst.NumInOperands() > 1) {
// SPIR-V initializers are always constants.
// (OpenCL also allows the ID of an OpVariable, but we don't handle that
// here.)
constructor = parser_impl_.MakeConstantExpression(inst.GetSingleWordInOperand(1)).expr;
if (!constructor) {
return false;
}
}
auto* var =
parser_impl_.MakeVariable(inst.result_id(), ast::StorageClass::kNone, var_store_type,
false, false, constructor, ast::AttributeList{});
auto* var_decl_stmt = create<ast::VariableDeclStatement>(Source{}, var);
AddStatement(var_decl_stmt);
auto* var_type = ty_.Reference(var_store_type, ast::StorageClass::kNone);
identifier_types_.emplace(inst.result_id(), var_type);
}
return success();
}
TypedExpression FunctionEmitter::AddressOfIfNeeded(TypedExpression expr,
const spvtools::opt::Instruction* inst) {
if (inst && expr) {
if (auto* spirv_type = type_mgr_->GetType(inst->type_id())) {
if (expr.type->Is<Reference>() && spirv_type->AsPointer()) {
return AddressOf(expr);
}
}
}
return expr;
}
TypedExpression FunctionEmitter::MakeExpression(uint32_t id) {
if (failed()) {
return {};
}
switch (GetSkipReason(id)) {
case SkipReason::kDontSkip:
break;
case SkipReason::kOpaqueObject:
Fail() << "internal error: unhandled use of opaque object with ID: " << id;
return {};
case SkipReason::kSinkPointerIntoUse: {
// Replace the pointer with its source reference expression.
auto source_expr = GetDefInfo(id)->sink_pointer_source_expr;
TINT_ASSERT(Reader, source_expr.type->Is<Reference>());
return source_expr;
}
case SkipReason::kPointSizeBuiltinValue: {
return {ty_.F32(), create<ast::FloatLiteralExpression>(
Source{}, 1.0, ast::FloatLiteralExpression::Suffix::kF)};
}
case SkipReason::kPointSizeBuiltinPointer:
Fail() << "unhandled use of a pointer to the PointSize builtin, with ID: " << id;
return {};
case SkipReason::kSampleMaskInBuiltinPointer:
Fail() << "unhandled use of a pointer to the SampleMask builtin, with ID: " << id;
return {};
case SkipReason::kSampleMaskOutBuiltinPointer: {
// The result type is always u32.
auto name = namer_.Name(sample_mask_out_id);
return TypedExpression{ty_.U32(), create<ast::IdentifierExpression>(
Source{}, builder_.Symbols().Register(name))};
}
}
auto type_it = identifier_types_.find(id);
if (type_it != identifier_types_.end()) {
auto name = namer_.Name(id);
auto* type = type_it->second;
return TypedExpression{
type, create<ast::IdentifierExpression>(Source{}, builder_.Symbols().Register(name))};
}
if (parser_impl_.IsScalarSpecConstant(id)) {
auto name = namer_.Name(id);
return TypedExpression{
parser_impl_.ConvertType(def_use_mgr_->GetDef(id)->type_id()),
create<ast::IdentifierExpression>(Source{}, builder_.Symbols().Register(name))};
}
if (singly_used_values_.count(id)) {
auto expr = std::move(singly_used_values_[id]);
singly_used_values_.erase(id);
return expr;
}
const auto* spirv_constant = constant_mgr_->FindDeclaredConstant(id);
if (spirv_constant) {
return parser_impl_.MakeConstantExpression(id);
}
const auto* inst = def_use_mgr_->GetDef(id);
if (inst == nullptr) {
Fail() << "ID " << id << " does not have a defining SPIR-V instruction";
return {};
}
switch (inst->opcode()) {
case SpvOpVariable: {
// This occurs for module-scope variables.
auto name = namer_.Name(inst->result_id());
return TypedExpression{
parser_impl_.ConvertType(inst->type_id(), PtrAs::Ref),
create<ast::IdentifierExpression>(Source{}, builder_.Symbols().Register(name))};
}
case SpvOpUndef:
// Substitute a null value for undef.
// This case occurs when OpUndef appears at module scope, as if it were
// a constant.
return parser_impl_.MakeNullExpression(parser_impl_.ConvertType(inst->type_id()));
default:
break;
}
if (const spvtools::opt::BasicBlock* const bb = ir_context_.get_instr_block(id)) {
if (auto* block = GetBlockInfo(bb->id())) {
if (block->pos == kInvalidBlockPos) {
// The value came from a block not in the block order.
// Substitute a null value.
return parser_impl_.MakeNullExpression(parser_impl_.ConvertType(inst->type_id()));
}
}
}
Fail() << "unhandled expression for ID " << id << "\n" << inst->PrettyPrint();
return {};
}
bool FunctionEmitter::EmitFunctionBodyStatements() {
// Dump the basic blocks in order, grouped by construct.
// We maintain a stack of StatementBlock objects, where new statements
// are always written to the topmost entry of the stack. By this point in
// processing, we have already recorded the interesting control flow
// boundaries in the BlockInfo and associated Construct objects. As we
// enter a new statement grouping, we push onto the stack, and also schedule
// the statement block's completion and removal at a future block's ID.
// Upon entry, the statement stack has one entry representing the whole
// function.
TINT_ASSERT(Reader, !constructs_.empty());
Construct* function_construct = constructs_[0].get();
TINT_ASSERT(Reader, function_construct != nullptr);
TINT_ASSERT(Reader, function_construct->kind == Construct::kFunction);
// Make the first entry valid by filling in the construct field, which
// had not been computed at the time the entry was first created.
// TODO(dneto): refactor how the first construct is created vs.
// this statements stack entry is populated.
TINT_ASSERT(Reader, statements_stack_.size() == 1);
statements_stack_[0].SetConstruct(function_construct);
for (auto block_id : block_order()) {
if (!EmitBasicBlock(*GetBlockInfo(block_id))) {
return false;
}
}
return success();
}
bool FunctionEmitter::EmitBasicBlock(const BlockInfo& block_info) {
// Close off previous constructs.
while (!statements_stack_.empty() && (statements_stack_.back().GetEndId() == block_info.id)) {
statements_stack_.back().Finalize(&builder_);
statements_stack_.pop_back();
}
if (statements_stack_.empty()) {
return Fail() << "internal error: statements stack empty at block " << block_info.id;
}
// Enter new constructs.
    std::vector<const Construct*> entering_constructs;  // innermost comes first
{
auto* here = block_info.construct;
auto* const top_construct = statements_stack_.back().GetConstruct();
while (here != top_construct) {
// Only enter a construct at its header block.
if (here->begin_id == block_info.id) {
entering_constructs.push_back(here);
}
here = here->parent;
}
}
// What constructs can we have entered?
// - It can't be kFunction, because there is only one of those, and it was
// already on the stack at the outermost level.
    // - We have at most one of kSwitchSelection or kLoop, because each of those
// is headed by a block with a merge instruction (OpLoopMerge for kLoop,
// and OpSelectionMerge for kSwitchSelection).
// - When there is a kIfSelection, it can't contain another construct,
// because both would have to have their own distinct merge instructions
// and distinct terminators.
// - A kContinue can contain a kContinue
// This is possible in Vulkan SPIR-V, but Tint disallows this by the rule
// that a block can be continue target for at most one header block. See
// test DISABLED_BlockIsContinueForMoreThanOneHeader. If we generalize this,
// then by a dominance argument, the inner loop continue target can only be
// a single-block loop.
// TODO(dneto): Handle this case.
// - If a kLoop is on the outside, its terminator is either:
// - an OpBranch, in which case there is no other construct.
// - an OpBranchConditional, in which case there is either an kIfSelection
// (when both branch targets are different and are inside the loop),
// or no other construct (because the branch targets are the same,
// or one of them is a break or continue).
// - All that's left is a kContinue on the outside, and one of
// kIfSelection, kSwitchSelection, kLoop on the inside.
//
// The kContinue can be the parent of the other. For example, a selection
// starting at the first block of a continue construct.
//
// The kContinue can't be the child of the other because either:
// - The other can't be kLoop because:
    //     - If the kLoop is for a different loop than the kContinue, then
    //       the kContinue must be its own loop header, and so the same
    //       block heads two different loops. That's a contradiction.
    //     - If the kLoop is for the same loop, then this is a contradiction
    //       because a kContinue and its kLoop have disjoint block sets.
// - The other construct can't be a selection because:
// - The kContinue construct is the entire loop, i.e. the continue
// target is its own loop header block. But then the continue target
// has an OpLoopMerge instruction, which contradicts this block being
// a selection header.
    //   - The kContinue is in a multi-block loop that has a non-empty
// kLoop; and the selection contains the kContinue block but not the
// loop block. That breaks dominance rules. That is, the continue
// target is dominated by that loop header, and so gets found by the
// block traversal on the outside before the selection is found. The
// selection is inside the outer loop.
//
// So we fall into one of the following cases:
// - We are entering 0 or 1 constructs, or
// - We are entering 2 constructs, with the outer one being a kContinue or
    //    kLoop, and the inner one is not a continue.
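    // For example (illustrative): a loop header whose OpBranchConditional
    // targets two distinct blocks inside the loop body enters both the
    // kLoop construct and its synthetic inner kIfSelection at that same
    // header block.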
if (entering_constructs.size() > 2) {
return Fail() << "internal error: bad construct nesting found";
}
if (entering_constructs.size() == 2) {
auto inner_kind = entering_constructs[0]->kind;
auto outer_kind = entering_constructs[1]->kind;
if (outer_kind != Construct::kContinue && outer_kind != Construct::kLoop) {
return Fail() << "internal error: bad construct nesting. Only a Continue "
"or a Loop construct can be outer construct on same block. "
"Got outer kind "
<< int(outer_kind) << " inner kind " << int(inner_kind);
}
if (inner_kind == Construct::kContinue) {
return Fail() << "internal error: unsupported construct nesting: "
"Continue around Continue";
}
if (inner_kind != Construct::kIfSelection && inner_kind != Construct::kSwitchSelection &&
inner_kind != Construct::kLoop) {
return Fail() << "internal error: bad construct nesting. Continue around "
"something other than if, switch, or loop";
}
}
// Enter constructs from outermost to innermost.
// kLoop and kContinue push a new statement-block onto the stack before
// emitting statements in the block.
// kIfSelection and kSwitchSelection emit statements in the block and then
    // push a new statement-block. Only emit the statements in the block
// once.
// Have we emitted the statements for this block?
bool emitted = false;
// When entering an if-selection or switch-selection, we will emit the WGSL
// construct to cause the divergent branching. But otherwise, we will
// emit a "normal" block terminator, which occurs at the end of this method.
bool has_normal_terminator = true;
for (auto iter = entering_constructs.rbegin(); iter != entering_constructs.rend(); ++iter) {
const Construct* construct = *iter;
switch (construct->kind) {
case Construct::kFunction:
return Fail() << "internal error: nested function construct";
case Construct::kLoop:
if (!EmitLoopStart(construct)) {
return false;
}
if (!EmitStatementsInBasicBlock(block_info, &emitted)) {
return false;
}
break;
case Construct::kContinue:
if (block_info.is_continue_entire_loop) {
if (!EmitLoopStart(construct)) {
return false;
}
if (!EmitStatementsInBasicBlock(block_info, &emitted)) {
return false;
}
} else {
if (!EmitContinuingStart(construct)) {
return false;
}
}
break;
case Construct::kIfSelection:
if (!EmitStatementsInBasicBlock(block_info, &emitted)) {
return false;
}
if (!EmitIfStart(block_info)) {
return false;
}
has_normal_terminator = false;
break;
case Construct::kSwitchSelection:
if (!EmitStatementsInBasicBlock(block_info, &emitted)) {
return false;
}
if (!EmitSwitchStart(block_info)) {
return false;
}
has_normal_terminator = false;
break;
}
}
// If we aren't starting or transitioning, then emit the normal
// statements now.
if (!EmitStatementsInBasicBlock(block_info, &emitted)) {
return false;
}
if (has_normal_terminator) {
if (!EmitNormalTerminator(block_info)) {
return false;
}
}
return success();
}
bool FunctionEmitter::EmitIfStart(const BlockInfo& block_info) {
// The block is the if-header block. So its construct is the if construct.
auto* construct = block_info.construct;
TINT_ASSERT(Reader, construct->kind == Construct::kIfSelection);
TINT_ASSERT(Reader, construct->begin_id == block_info.id);
const uint32_t true_head = block_info.true_head;
const uint32_t false_head = block_info.false_head;
const uint32_t premerge_head = block_info.premerge_head;
const std::string guard_name = block_info.flow_guard_name;
if (!guard_name.empty()) {
// Declare the guard variable just before the "if", initialized to true.
auto* guard_var = builder_.Var(guard_name, builder_.ty.bool_(), MakeTrue(Source{}));
auto* guard_decl = create<ast::VariableDeclStatement>(Source{}, guard_var);
AddStatement(guard_decl);
}
const auto condition_id = block_info.basic_block->terminator()->GetSingleWordInOperand(0);
auto* cond = MakeExpression(condition_id).expr;
if (!cond) {
return false;
}
// Generate the code for the condition.
auto* builder = AddStatementBuilder<IfStatementBuilder>(cond);
// Compute the block IDs that should end the then-clause and the else-clause.
// We need to know where the *emitted* selection should end, i.e. the intended
// merge block id. That should be the current premerge block, if it exists,
// or otherwise the declared merge block.
//
// This is another way to think about it:
// If there is a premerge, then there are three cases:
// - premerge_head is different from the true_head and false_head:
// - Premerge comes last. In effect, move the selection merge up
// to where the premerge begins.
// - premerge_head is the same as the false_head
// - This is really an if-then without an else clause.
// Move the merge up to where the premerge is.
// - premerge_head is the same as the true_head
    //     - This is really an if-else without a then clause.
// Emit it as: if (cond) {} else {....}
// Move the merge up to where the premerge is.
const uint32_t intended_merge = premerge_head ? premerge_head : construct->end_id;
// then-clause:
// If true_head exists:
    //     spans from the true head to the earlier of the false head (if it
    //     exists) or the selection merge.
    //   Otherwise:
    //     ends at the false head (if it exists), otherwise at the selection
    //     end.
const uint32_t then_end = false_head ? false_head : intended_merge;
// else-clause:
// ends at the premerge head (if it exists) or at the selection end.
const uint32_t else_end = premerge_head ? premerge_head : intended_merge;
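    // For example (an illustrative sketch with made-up IDs): with true-head
    // %20, false-head %30, premerge-head %40, and declared merge %50, the
    // intended merge is %40, the then-clause ends at %30, and the else-clause
    // ends at %40. Blocks from %40 up to %50 are emitted after the 'if',
    // under a guard.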
const bool true_is_break = (block_info.true_kind == EdgeKind::kSwitchBreak) ||
(block_info.true_kind == EdgeKind::kLoopBreak);
const bool false_is_break = (block_info.false_kind == EdgeKind::kSwitchBreak) ||
(block_info.false_kind == EdgeKind::kLoopBreak);
const bool true_is_continue = block_info.true_kind == EdgeKind::kLoopContinue;
const bool false_is_continue = block_info.false_kind == EdgeKind::kLoopContinue;
// Push statement blocks for the then-clause and the else-clause.
// But make sure we do it in the right order.
auto push_else = [this, builder, else_end, construct, false_is_break, false_is_continue]() {
// Push the else clause onto the stack first.
PushNewStatementBlock(construct, else_end, [=](const ast::StatementList& stmts) {
// Only set the else-clause if there are statements to fill it.
if (!stmts.empty()) {
// The "else" consists of the statement list from the top of
// statements stack, without an "else if" condition.
builder->else_stmt = create<ast::BlockStatement>(Source{}, stmts);
}
});
if (false_is_break) {
AddStatement(create<ast::BreakStatement>(Source{}));
}
if (false_is_continue) {
AddStatement(create<ast::ContinueStatement>(Source{}));
}
};
if (!true_is_break && !true_is_continue &&
(GetBlockInfo(else_end)->pos < GetBlockInfo(then_end)->pos)) {
// Process the else-clause first. The then-clause will be empty so avoid
// pushing onto the stack at all.
push_else();
} else {
// Blocks for the then-clause appear before blocks for the else-clause.
// So push the else-clause handling onto the stack first. The else-clause
// might be empty, but this works anyway.
// Handle the premerge, if it exists.
if (premerge_head) {
// The top of the stack is the statement block that is the parent of the
// if-statement. Adding statements now will place them after that 'if'.
if (guard_name.empty()) {
// We won't have a flow guard for the premerge.
// Insert a trivial if(true) { ... } around the blocks from the
// premerge head until the end of the if-selection. This is needed
// to ensure uniform reconvergence occurs at the end of the if-selection
// just like in the original SPIR-V.
PushTrueGuard(construct->end_id);
} else {
// Add a flow guard around the blocks in the premerge area.
PushGuard(guard_name, construct->end_id);
}
}
push_else();
if (true_head && false_head && !guard_name.empty()) {
// There are non-trivial then and else clauses.
// We have to guard the start of the else.
PushGuard(guard_name, else_end);
}
// Push the then clause onto the stack.
PushNewStatementBlock(construct, then_end, [=](const ast::StatementList& stmts) {
builder->body = create<ast::BlockStatement>(Source{}, stmts);
});
if (true_is_break) {
AddStatement(create<ast::BreakStatement>(Source{}));
}
if (true_is_continue) {
AddStatement(create<ast::ContinueStatement>(Source{}));
}
}
return success();
}
bool FunctionEmitter::EmitSwitchStart(const BlockInfo& block_info) {
    // The block is the switch-header block. So its construct is the switch construct.
auto* construct = block_info.construct;
TINT_ASSERT(Reader, construct->kind == Construct::kSwitchSelection);
TINT_ASSERT(Reader, construct->begin_id == block_info.id);
const auto* branch = block_info.basic_block->terminator();
const auto selector_id = branch->GetSingleWordInOperand(0);
// Generate the code for the selector.
auto selector = MakeExpression(selector_id);
if (!selector) {
return false;
}
// First, push the statement block for the entire switch.
auto* swch = AddStatementBuilder<SwitchStatementBuilder>(selector.expr);
// Grab a pointer to the case list. It will get buried in the statement block
// stack.
PushNewStatementBlock(construct, construct->end_id, nullptr);
// We will push statement-blocks onto the stack to gather the statements in
// the default clause and cases clauses. Determine the list of blocks
// that start each clause.
std::vector<const BlockInfo*> clause_heads;
// Collect the case clauses, even if they are just the merge block.
// First the default clause.
const auto default_id = branch->GetSingleWordInOperand(1);
const auto* default_info = GetBlockInfo(default_id);
clause_heads.push_back(default_info);
// Now the case clauses.
for (uint32_t iarg = 2; iarg + 1 < branch->NumInOperands(); iarg += 2) {
const auto case_target_id = branch->GetSingleWordInOperand(iarg + 1);
clause_heads.push_back(GetBlockInfo(case_target_id));
}
std::stable_sort(
clause_heads.begin(), clause_heads.end(),
[](const BlockInfo* lhs, const BlockInfo* rhs) { return lhs->pos < rhs->pos; });
// Remove duplicates
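    // For example, sorted clause heads {%20, %20, %30, %40, %40} are
    // compacted in place to {%20, %30, %40}.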
{
// Use read index r, and write index w.
// Invariant: w <= r;
size_t w = 0;
for (size_t r = 0; r < clause_heads.size(); ++r) {
if (clause_heads[r] != clause_heads[w]) {
++w; // Advance the write cursor.
}
clause_heads[w] = clause_heads[r];
}
// We know it's not empty because it always has at least a default clause.
TINT_ASSERT(Reader, !clause_heads.empty());
clause_heads.resize(w + 1);
}
// Push them on in reverse order.
const auto last_clause_index = clause_heads.size() - 1;
for (size_t i = last_clause_index;; --i) {
// Create a list of integer literals for the selector values leading to
// this case clause.
ast::CaseSelectorList selectors;
const auto* values_ptr = clause_heads[i]->case_values.get();
const bool has_selectors = (values_ptr && !values_ptr->empty());
if (has_selectors) {
std::vector<uint64_t> values(values_ptr->begin(), values_ptr->end());
std::stable_sort(values.begin(), values.end());
for (auto value : values) {
                // The rest of this module can handle up to 64-bit switch values.
// The Tint AST handles 32-bit values.
const uint32_t value32 = uint32_t(value & 0xFFFFFFFF);
if (selector.type->IsUnsignedScalarOrVector()) {
selectors.emplace_back(create<ast::IntLiteralExpression>(
Source{}, value32, ast::IntLiteralExpression::Suffix::kU));
} else {
selectors.emplace_back(
create<ast::IntLiteralExpression>(Source{}, static_cast<int32_t>(value32),
ast::IntLiteralExpression::Suffix::kI));
}
}
}
// Where does this clause end?
const auto end_id =
(i + 1 < clause_heads.size()) ? clause_heads[i + 1]->id : construct->end_id;
// Reserve the case clause slot in swch->cases, push the new statement block
// for the case, and fill the case clause once the block is generated.
auto case_idx = swch->cases.size();
swch->cases.emplace_back(nullptr);
PushNewStatementBlock(construct, end_id, [=](const ast::StatementList& stmts) {
auto* body = create<ast::BlockStatement>(Source{}, stmts);
swch->cases[case_idx] = create<ast::CaseStatement>(Source{}, selectors, body);
});
if ((default_info == clause_heads[i]) && has_selectors &&
construct->ContainsPos(default_info->pos)) {
            // Generate a default clause with just a fallthrough.
auto* stmts = create<ast::BlockStatement>(
Source{}, ast::StatementList{
create<ast::FallthroughStatement>(Source{}),
});
auto* case_stmt = create<ast::CaseStatement>(Source{}, ast::CaseSelectorList{}, stmts);
swch->cases.emplace_back(case_stmt);
}
if (i == 0) {
break;
}
}
return success();
}
bool FunctionEmitter::EmitLoopStart(const Construct* construct) {
auto* builder = AddStatementBuilder<LoopStatementBuilder>();
PushNewStatementBlock(construct, construct->end_id, [=](const ast::StatementList& stmts) {
builder->body = create<ast::BlockStatement>(Source{}, stmts);
});
return success();
}
bool FunctionEmitter::EmitContinuingStart(const Construct* construct) {
// A continue construct has the same depth as its associated loop
// construct. Start a continue construct.
auto* loop_candidate = LastStatement();
auto* loop = loop_candidate->As<LoopStatementBuilder>();
if (loop == nullptr) {
return Fail() << "internal error: starting continue construct, "
"expected loop on top of stack";
}
PushNewStatementBlock(construct, construct->end_id, [=](const ast::StatementList& stmts) {
loop->continuing = create<ast::BlockStatement>(Source{}, stmts);
});
return success();
}
bool FunctionEmitter::EmitNormalTerminator(const BlockInfo& block_info) {
const auto& terminator = *(block_info.basic_block->terminator());
switch (terminator.opcode()) {
case SpvOpReturn:
AddStatement(create<ast::ReturnStatement>(Source{}));
return true;
case SpvOpReturnValue: {
auto value = MakeExpression(terminator.GetSingleWordInOperand(0));
if (!value) {
return false;
}
AddStatement(create<ast::ReturnStatement>(Source{}, value.expr));
}
return true;
case SpvOpKill:
// For now, assume SPIR-V OpKill has same semantics as WGSL discard.
// TODO(dneto): https://github.com/gpuweb/gpuweb/issues/676
AddStatement(create<ast::DiscardStatement>(Source{}));
return true;
case SpvOpUnreachable:
// Translate as if it's a return. This avoids the problem where WGSL
// requires a return statement at the end of the function body.
{
const auto* result_type = type_mgr_->GetType(function_.type_id());
if (result_type->AsVoid() != nullptr) {
AddStatement(create<ast::ReturnStatement>(Source{}));
} else {
auto* ast_type = parser_impl_.ConvertType(function_.type_id());
AddStatement(create<ast::ReturnStatement>(
Source{}, parser_impl_.MakeNullValue(ast_type)));
}
}
return true;
case SpvOpBranch: {
const auto dest_id = terminator.GetSingleWordInOperand(0);
AddStatement(MakeBranch(block_info, *GetBlockInfo(dest_id)));
return true;
}
case SpvOpBranchConditional: {
// If both destinations are the same, then do the same as we would
// for an unconditional branch (OpBranch).
const auto true_dest = terminator.GetSingleWordInOperand(1);
const auto false_dest = terminator.GetSingleWordInOperand(2);
if (true_dest == false_dest) {
// This is like an unconditional branch.
AddStatement(MakeBranch(block_info, *GetBlockInfo(true_dest)));
return true;
}
const EdgeKind true_kind = block_info.succ_edge.find(true_dest)->second;
const EdgeKind false_kind = block_info.succ_edge.find(false_dest)->second;
auto* const true_info = GetBlockInfo(true_dest);
auto* const false_info = GetBlockInfo(false_dest);
auto* cond = MakeExpression(terminator.GetSingleWordInOperand(0)).expr;
if (!cond) {
return false;
}
// We have two distinct destinations. But we only get here if this
// is a normal terminator; in particular the source block is *not* the
// start of an if-selection or a switch-selection. So at most one branch
// is a kForward, kCaseFallThrough, or kIfBreak.
// The fallthrough case is special because WGSL requires the fallthrough
// statement to be last in the case clause.
if (true_kind == EdgeKind::kCaseFallThrough) {
return EmitConditionalCaseFallThrough(block_info, cond, false_kind, *false_info,
true);
} else if (false_kind == EdgeKind::kCaseFallThrough) {
return EmitConditionalCaseFallThrough(block_info, cond, true_kind, *true_info,
false);
}
// At this point, at most one edge is kForward or kIfBreak.
// Emit an 'if' statement to express the *other* branch as a conditional
// break or continue. Either or both of these could be nullptr.
// (A nullptr is generated for kIfBreak, kForward, or kBack.)
// Also if one of the branches is an if-break out of an if-selection
// requiring a flow guard, then get that flow guard name too. It will
// come from at most one of these two branches.
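            //
            // For example (an illustrative sketch): a conditional loop-break
            //     OpBranchConditional %cond %stay %merge
            // becomes
            //     if (cond) { } else { break; }
            // since the kForward true branch is implicit (nullptr) and the
            // false branch is a kLoopBreak.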
std::string flow_guard;
auto* true_branch = MakeBranchDetailed(block_info, *true_info, false, &flow_guard);
auto* false_branch = MakeBranchDetailed(block_info, *false_info, false, &flow_guard);
AddStatement(MakeSimpleIf(cond, true_branch, false_branch));
if (!flow_guard.empty()) {
PushGuard(flow_guard, statements_stack_.back().GetEndId());
}
return true;
}
case SpvOpSwitch:
// An OpSelectionMerge must precede an OpSwitch. That is clarified
// in the resolution to Khronos-internal SPIR-V issue 115.
// A new enough version of the SPIR-V validator checks this case.
// But issue an error in this case, as a defensive measure.
return Fail() << "invalid structured control flow: found an OpSwitch "
"that is not preceded by an "
"OpSelectionMerge: "
<< terminator.PrettyPrint();
default:
break;
}
return success();
}
const ast::Statement* FunctionEmitter::MakeBranchDetailed(const BlockInfo& src_info,
const BlockInfo& dest_info,
bool forced,
std::string* flow_guard_name_ptr) const {
auto kind = src_info.succ_edge.find(dest_info.id)->second;
switch (kind) {
case EdgeKind::kBack:
// Nothing to do. The loop backedge is implicit.
break;
case EdgeKind::kSwitchBreak: {
if (forced) {
return create<ast::BreakStatement>(Source{});
}
// Unless forced, don't bother with a break at the end of a case/default
// clause.
const auto header = dest_info.header_for_merge;
TINT_ASSERT(Reader, header != 0);
const auto* exiting_construct = GetBlockInfo(header)->construct;
TINT_ASSERT(Reader, exiting_construct->kind == Construct::kSwitchSelection);
const auto candidate_next_case_pos = src_info.pos + 1;
// Leaving the last block from the last case?
if (candidate_next_case_pos == dest_info.pos) {
// No break needed.
return nullptr;
}
// Leaving the last block from not-the-last-case?
if (exiting_construct->ContainsPos(candidate_next_case_pos)) {
const auto* candidate_next_case =
GetBlockInfo(block_order_[candidate_next_case_pos]);
if (candidate_next_case->case_head_for == exiting_construct ||
candidate_next_case->default_head_for == exiting_construct) {
// No break needed.
return nullptr;
}
}
// We need a break.
return create<ast::BreakStatement>(Source{});
}
case EdgeKind::kLoopBreak:
return create<ast::BreakStatement>(Source{});
case EdgeKind::kLoopContinue:
// An unconditional continue to the next block is redundant and ugly.
// Skip it in that case.
if (dest_info.pos == 1 + src_info.pos) {
break;
}
// Otherwise, emit a regular continue statement.
return create<ast::ContinueStatement>(Source{});
case EdgeKind::kIfBreak: {
const auto& flow_guard = GetBlockInfo(dest_info.header_for_merge)->flow_guard_name;
if (!flow_guard.empty()) {
if (flow_guard_name_ptr != nullptr) {
*flow_guard_name_ptr = flow_guard;
}
// Signal an exit from the branch.
return create<ast::AssignmentStatement>(
Source{},
create<ast::IdentifierExpression>(Source{},
builder_.Symbols().Register(flow_guard)),
MakeFalse(Source{}));
}
// For an unconditional branch, the break out to an if-selection
// merge block is implicit.
break;
}
case EdgeKind::kCaseFallThrough:
return create<ast::FallthroughStatement>(Source{});
case EdgeKind::kForward:
// Unconditional forward branch is implicit.
break;
}
return nullptr;
}
const ast::Statement* FunctionEmitter::MakeSimpleIf(const ast::Expression* condition,
const ast::Statement* then_stmt,
const ast::Statement* else_stmt) const {
if ((then_stmt == nullptr) && (else_stmt == nullptr)) {
return nullptr;
}
ast::StatementList if_stmts;
if (then_stmt != nullptr) {
if_stmts.emplace_back(then_stmt);
}
auto* if_block = create<ast::BlockStatement>(Source{}, if_stmts);
const ast::Statement* else_block = nullptr;
if (else_stmt) {
else_block = create<ast::BlockStatement>(ast::StatementList{else_stmt});
}
auto* if_stmt = create<ast::IfStatement>(Source{}, condition, if_block, else_block);
return if_stmt;
}
bool FunctionEmitter::EmitConditionalCaseFallThrough(const BlockInfo& src_info,
const ast::Expression* cond,
EdgeKind other_edge_kind,
const BlockInfo& other_dest,
bool fall_through_is_true_branch) {
// In WGSL, the fallthrough statement must come last in the case clause.
// So we'll emit an if statement for the other branch, and then emit
// the fallthrough.
// We have two distinct destinations. But we only get here if this
// is a normal terminator; in particular the source block is *not* the
// start of an if-selection. So at most one branch is a kForward or
// kCaseFallThrough.
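    //
    // For example (an illustrative sketch): when the fallthrough is on the
    // true branch and the other edge is a loop-break, we emit:
    //     if (cond) { } else { break; }
    //     fallthrough;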
if (other_edge_kind == EdgeKind::kForward) {
return Fail() << "internal error: normal terminator OpBranchConditional has "
"both forward and fallthrough edges";
}
if (other_edge_kind == EdgeKind::kIfBreak) {
return Fail() << "internal error: normal terminator OpBranchConditional has "
"both IfBreak and fallthrough edges. Violates nesting rule";
}
if (other_edge_kind == EdgeKind::kBack) {
return Fail() << "internal error: normal terminator OpBranchConditional has "
"both backedge and fallthrough edges. Violates nesting rule";
}
auto* other_branch = MakeForcedBranch(src_info, other_dest);
if (other_branch == nullptr) {
return Fail() << "internal error: expected a branch for edge-kind " << int(other_edge_kind);
}
if (fall_through_is_true_branch) {
AddStatement(MakeSimpleIf(cond, nullptr, other_branch));
} else {
AddStatement(MakeSimpleIf(cond, other_branch, nullptr));
}
AddStatement(create<ast::FallthroughStatement>(Source{}));
return success();
}
bool FunctionEmitter::EmitStatementsInBasicBlock(const BlockInfo& block_info,
bool* already_emitted) {
if (*already_emitted) {
// Only emit this part of the basic block once.
return true;
}
// Returns the given list of local definition IDs, sorted by their index.
auto sorted_by_index = [this](const std::vector<uint32_t>& ids) {
auto sorted = ids;
std::stable_sort(sorted.begin(), sorted.end(),
[this](const uint32_t lhs, const uint32_t rhs) {
return GetDefInfo(lhs)->index < GetDefInfo(rhs)->index;
});
return sorted;
};
// Emit declarations of hoisted variables, in index order.
for (auto id : sorted_by_index(block_info.hoisted_ids)) {
const auto* def_inst = def_use_mgr_->GetDef(id);
TINT_ASSERT(Reader, def_inst);
auto* storage_type = RemapStorageClass(parser_impl_.ConvertType(def_inst->type_id()), id);
AddStatement(create<ast::VariableDeclStatement>(
Source{}, parser_impl_.MakeVariable(id, ast::StorageClass::kNone, storage_type, false,
false, nullptr, ast::AttributeList{})));
auto* type = ty_.Reference(storage_type, ast::StorageClass::kNone);
identifier_types_.emplace(id, type);
}
// Emit declarations of phi state variables, in index order.
for (auto id : sorted_by_index(block_info.phis_needing_state_vars)) {
const auto* def_inst = def_use_mgr_->GetDef(id);
TINT_ASSERT(Reader, def_inst);
const auto phi_var_name = GetDefInfo(id)->phi_var;
TINT_ASSERT(Reader, !phi_var_name.empty());
auto* var = builder_.Var(phi_var_name,
parser_impl_.ConvertType(def_inst->type_id())->Build(builder_));
AddStatement(create<ast::VariableDeclStatement>(Source{}, var));
}
// Emit regular statements.
const spvtools::opt::BasicBlock& bb = *(block_info.basic_block);
const auto* terminator = bb.terminator();
const auto* merge = bb.GetMergeInst(); // Might be nullptr
for (auto& inst : bb) {
if (&inst == terminator || &inst == merge || inst.opcode() == SpvOpLabel ||
inst.opcode() == SpvOpVariable) {
continue;
}
if (!EmitStatement(inst)) {
return false;
}
}
// Emit assignments to carry values to phi nodes in potential destinations.
// Do it in index order.
if (!block_info.phi_assignments.empty()) {
auto sorted = block_info.phi_assignments;
std::stable_sort(
sorted.begin(), sorted.end(),
[this](const BlockInfo::PhiAssignment& lhs, const BlockInfo::PhiAssignment& rhs) {
return GetDefInfo(lhs.phi_id)->index < GetDefInfo(rhs.phi_id)->index;
});
        for (const auto& assignment : sorted) {
const auto var_name = GetDefInfo(assignment.phi_id)->phi_var;
auto expr = MakeExpression(assignment.value);
if (!expr) {
return false;
}
AddStatement(create<ast::AssignmentStatement>(
Source{},
create<ast::IdentifierExpression>(Source{}, builder_.Symbols().Register(var_name)),
expr.expr));
}
}
*already_emitted = true;
return true;
}
bool FunctionEmitter::EmitConstDefinition(const spvtools::opt::Instruction& inst,
TypedExpression expr) {
if (!expr) {
return false;
}
// Do not generate pointers that we want to sink.
if (GetDefInfo(inst.result_id())->skip == SkipReason::kSinkPointerIntoUse) {
return true;
}
expr = AddressOfIfNeeded(expr, &inst);
auto* ast_const =
parser_impl_.MakeVariable(inst.result_id(), ast::StorageClass::kNone, expr.type, true,
false, expr.expr, ast::AttributeList{});
if (!ast_const) {
return false;
}
AddStatement(create<ast::VariableDeclStatement>(Source{}, ast_const));
identifier_types_.emplace(inst.result_id(), expr.type);
return success();
}
bool FunctionEmitter::EmitConstDefOrWriteToHoistedVar(const spvtools::opt::Instruction& inst,
TypedExpression expr) {
return WriteIfHoistedVar(inst, expr) || EmitConstDefinition(inst, expr);
}
bool FunctionEmitter::WriteIfHoistedVar(const spvtools::opt::Instruction& inst,
TypedExpression expr) {
const auto result_id = inst.result_id();
const auto* def_info = GetDefInfo(result_id);
if (def_info && def_info->requires_hoisted_def) {
auto name = namer_.Name(result_id);
// Emit an assignment of the expression to the hoisted variable.
AddStatement(create<ast::AssignmentStatement>(
Source{},
create<ast::IdentifierExpression>(Source{}, builder_.Symbols().Register(name)),
expr.expr));
return true;
}
return false;
}
bool FunctionEmitter::EmitStatement(const spvtools::opt::Instruction& inst) {
if (failed()) {
return false;
}
const auto result_id = inst.result_id();
const auto type_id = inst.type_id();
if (type_id != 0) {
const auto& builtin_position_info = parser_impl_.GetBuiltInPositionInfo();
if (type_id == builtin_position_info.struct_type_id) {
return Fail() << "operations producing a per-vertex structure are not "
"supported: "
<< inst.PrettyPrint();
}
if (type_id == builtin_position_info.pointer_type_id) {
return Fail() << "operations producing a pointer to a per-vertex "
"structure are not "
"supported: "
<< inst.PrettyPrint();
}
}
// Handle combinatorial instructions.
const auto* def_info = GetDefInfo(result_id);
if (def_info) {
TypedExpression combinatorial_expr;
if (def_info->skip == SkipReason::kDontSkip) {
combinatorial_expr = MaybeEmitCombinatorialValue(inst);
if (!success()) {
return false;
}
}
// An access chain or OpCopyObject can generate a skip.
if (def_info->skip != SkipReason::kDontSkip) {
return true;
}
if (combinatorial_expr.expr != nullptr) {
if (def_info->requires_hoisted_def || def_info->requires_named_const_def ||
def_info->num_uses != 1) {
// Generate a const definition or an assignment to a hoisted definition
// now and later use the const or variable name at the uses of this
// value.
return EmitConstDefOrWriteToHoistedVar(inst, combinatorial_expr);
}
// It is harmless to defer emitting the expression until it's used.
// Any supporting statements have already been emitted.
singly_used_values_.insert(std::make_pair(result_id, combinatorial_expr));
return success();
}
}
if (failed()) {
return false;
}
if (IsImageQuery(inst.opcode())) {
return EmitImageQuery(inst);
}
if (IsSampledImageAccess(inst.opcode()) || IsRawImageAccess(inst.opcode())) {
return EmitImageAccess(inst);
}
switch (inst.opcode()) {
case SpvOpNop:
return true;
case SpvOpStore: {
auto ptr_id = inst.GetSingleWordInOperand(0);
const auto value_id = inst.GetSingleWordInOperand(1);
const auto ptr_type_id = def_use_mgr_->GetDef(ptr_id)->type_id();
const auto& builtin_position_info = parser_impl_.GetBuiltInPositionInfo();
if (ptr_type_id == builtin_position_info.pointer_type_id) {
return Fail() << "storing to the whole per-vertex structure is not supported: "
<< inst.PrettyPrint();
}
TypedExpression rhs = MakeExpression(value_id);
if (!rhs) {
return false;
}
TypedExpression lhs;
// Handle exceptional cases
switch (GetSkipReason(ptr_id)) {
case SkipReason::kPointSizeBuiltinPointer:
if (IsFloatOne(value_id)) {
// Don't store to PointSize
return true;
}
return Fail() << "cannot store a value other than constant 1.0 to "
"PointSize builtin: "
<< inst.PrettyPrint();
case SkipReason::kSampleMaskOutBuiltinPointer:
lhs = MakeExpression(sample_mask_out_id);
if (lhs.type->Is<Pointer>()) {
// LHS of an assignment must be a reference type.
// Convert the LHS to a reference by dereferencing it.
lhs = Dereference(lhs);
}
// The private variable is an array whose element type is already of
// the same type as the value being stored into it. Form the
// reference into the first element.
lhs.expr = create<ast::IndexAccessorExpression>(
Source{}, lhs.expr, parser_impl_.MakeNullValue(ty_.I32()));
if (auto* ref = lhs.type->As<Reference>()) {
lhs.type = ref->type;
}
if (auto* arr = lhs.type->As<Array>()) {
lhs.type = arr->type;
}
TINT_ASSERT(Reader, lhs.type);
break;
default:
break;
}
// Handle an ordinary store as an assignment.
if (!lhs) {
lhs = MakeExpression(ptr_id);
}
if (!lhs) {
return false;
}
if (lhs.type->Is<Pointer>()) {
// LHS of an assignment must be a reference type.
// Convert the LHS to a reference by dereferencing it.
lhs = Dereference(lhs);
}
AddStatement(create<ast::AssignmentStatement>(Source{}, lhs.expr, rhs.expr));
return success();
}
case SpvOpLoad: {
// Memory accesses must be issued in SPIR-V program order.
// So represent a load by a new const definition.
const auto ptr_id = inst.GetSingleWordInOperand(0);
const auto skip_reason = GetSkipReason(ptr_id);
switch (skip_reason) {
case SkipReason::kPointSizeBuiltinPointer:
GetDefInfo(inst.result_id())->skip = SkipReason::kPointSizeBuiltinValue;
return true;
case SkipReason::kSampleMaskInBuiltinPointer: {
auto name = namer_.Name(sample_mask_in_id);
const ast::Expression* id_expr = create<ast::IdentifierExpression>(
Source{}, builder_.Symbols().Register(name));
// SampleMask is an array in Vulkan SPIR-V. Always access the first
// element.
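// Sketch (variable name assumed): the load is emitted as a read of
// 'x_sample_mask_in[0]' rather than of the whole array.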
id_expr = create<ast::IndexAccessorExpression>(
Source{}, id_expr, parser_impl_.MakeNullValue(ty_.I32()));
auto* loaded_type = parser_impl_.ConvertType(inst.type_id());
if (!loaded_type->IsIntegerScalar()) {
return Fail() << "loading the whole SampleMask input array is not "
"supported: "
<< inst.PrettyPrint();
}
auto expr = TypedExpression{loaded_type, id_expr};
return EmitConstDefinition(inst, expr);
}
default:
break;
}
auto expr = MakeExpression(ptr_id);
if (!expr) {
return false;
}
// The load result type is the storage type of its operand.
if (expr.type->Is<Pointer>()) {
expr = Dereference(expr);
} else if (auto* ref = expr.type->As<Reference>()) {
expr.type = ref->type;
} else {
Fail() << "OpLoad expression is not a pointer or reference";
return false;
}
return EmitConstDefOrWriteToHoistedVar(inst, expr);
}
case SpvOpCopyMemory: {
// Generate an assignment.
auto lhs = MakeOperand(inst, 0);
auto rhs = MakeOperand(inst, 1);
// Ignore any potential memory operands. Currently they are all for
// concepts not in WGSL:
// Volatile
// Aligned
// Nontemporal
// MakePointerAvailable ; Vulkan memory model
// MakePointerVisible ; Vulkan memory model
// NonPrivatePointer ; Vulkan memory model
if (!success()) {
return false;
}
// LHS and RHS pointers must be reference types in WGSL.
if (lhs.type->Is<Pointer>()) {
lhs = Dereference(lhs);
}
if (rhs.type->Is<Pointer>()) {
rhs = Dereference(rhs);
}
AddStatement(create<ast::AssignmentStatement>(Source{}, lhs.expr, rhs.expr));
return success();
}
case SpvOpCopyObject: {
// Arguably, OpCopyObject is purely combinatorial. On the other hand,
// it exists to make a new name for something. So we choose to make
// a new named constant definition.
auto value_id = inst.GetSingleWordInOperand(0);
const auto skip = GetSkipReason(value_id);
if (skip != SkipReason::kDontSkip) {
GetDefInfo(inst.result_id())->skip = skip;
GetDefInfo(inst.result_id())->sink_pointer_source_expr =
GetDefInfo(value_id)->sink_pointer_source_expr;
return true;
}
auto expr = AddressOfIfNeeded(MakeExpression(value_id), &inst);
if (!expr) {
return false;
}
expr.type = RemapStorageClass(expr.type, result_id);
return EmitConstDefOrWriteToHoistedVar(inst, expr);
}
case SpvOpPhi: {
// Emit a read from the associated state variable.
TypedExpression expr{parser_impl_.ConvertType(inst.type_id()),
create<ast::IdentifierExpression>(
Source{}, builder_.Symbols().Register(def_info->phi_var))};
return EmitConstDefOrWriteToHoistedVar(inst, expr);
}
case SpvOpOuterProduct:
// Synthesize an outer product expression in its own statement.
return EmitConstDefOrWriteToHoistedVar(inst, MakeOuterProduct(inst));
case SpvOpVectorInsertDynamic:
// Synthesize a vector insertion in its own statements.
return MakeVectorInsertDynamic(inst);
case SpvOpCompositeInsert:
// Synthesize a composite insertion in its own statements.
return MakeCompositeInsert(inst);
case SpvOpFunctionCall:
return EmitFunctionCall(inst);
case SpvOpControlBarrier:
return EmitControlBarrier(inst);
case SpvOpExtInst:
if (parser_impl_.IsIgnoredExtendedInstruction(inst)) {
return true;
}
break;
case SpvOpIAddCarry:
case SpvOpISubBorrow:
case SpvOpUMulExtended:
case SpvOpSMulExtended:
return Fail() << "extended arithmetic is not finalized for WGSL: "
"https://github.com/gpuweb/gpuweb/issues/1565: "
<< inst.PrettyPrint();
default:
break;
}
return Fail() << "unhandled instruction with opcode " << inst.opcode() << ": "
<< inst.PrettyPrint();
}
TypedExpression FunctionEmitter::MakeOperand(const spvtools::opt::Instruction& inst,
uint32_t operand_index) {
auto expr = MakeExpression(inst.GetSingleWordInOperand(operand_index));
if (!expr) {
return {};
}
return parser_impl_.RectifyOperandSignedness(inst, std::move(expr));
}
TypedExpression FunctionEmitter::InferFunctionStorageClass(TypedExpression expr) {
if (const auto* ref = expr.type->UnwrapAlias()->As<Reference>()) {
if (ref->storage_class == ast::StorageClass::kNone) {
expr.type = ty_.Reference(ref->type, ast::StorageClass::kFunction);
}
} else if (const auto* ptr = expr.type->UnwrapAlias()->As<Pointer>()) {
if (ptr->storage_class == ast::StorageClass::kNone) {
expr.type = ty_.Pointer(ptr->type, ast::StorageClass::kFunction);
}
}
return expr;
}
TypedExpression FunctionEmitter::MaybeEmitCombinatorialValue(
const spvtools::opt::Instruction& inst) {
if (inst.result_id() == 0) {
return {};
}
const auto opcode = inst.opcode();
const Type* ast_type = nullptr;
if (inst.type_id()) {
ast_type = parser_impl_.ConvertType(inst.type_id());
if (!ast_type) {
Fail() << "couldn't convert result type for: " << inst.PrettyPrint();
return {};
}
}
auto binary_op = ConvertBinaryOp(opcode);
if (binary_op != ast::BinaryOp::kNone) {
auto arg0 = MakeOperand(inst, 0);
auto arg1 =
parser_impl_.RectifySecondOperandSignedness(inst, arg0.type, MakeOperand(inst, 1));
if (!arg0 || !arg1) {
return {};
}
auto* binary_expr =
create<ast::BinaryExpression>(Source{}, binary_op, arg0.expr, arg1.expr);
TypedExpression result{ast_type, binary_expr};
return parser_impl_.RectifyForcedResultType(result, inst, arg0.type);
}
auto unary_op = ast::UnaryOp::kNegation;
if (GetUnaryOp(opcode, &unary_op)) {
auto arg0 = MakeOperand(inst, 0);
auto* unary_expr = create<ast::UnaryOpExpression>(Source{}, unary_op, arg0.expr);
TypedExpression result{ast_type, unary_expr};
return parser_impl_.RectifyForcedResultType(result, inst, arg0.type);
}
const char* unary_builtin_name = GetUnaryBuiltInFunctionName(opcode);
if (unary_builtin_name != nullptr) {
ast::ExpressionList params;
params.emplace_back(MakeOperand(inst, 0).expr);
return {ast_type, create<ast::CallExpression>(
Source{},
create<ast::IdentifierExpression>(
Source{}, builder_.Symbols().Register(unary_builtin_name)),
std::move(params))};
}
const auto builtin = GetBuiltin(opcode);
if (builtin != sem::BuiltinType::kNone) {
return MakeBuiltinCall(inst);
}
if (opcode == SpvOpFMod) {
return MakeFMod(inst);
}
if (opcode == SpvOpAccessChain || opcode == SpvOpInBoundsAccessChain) {
return MakeAccessChain(inst);
}
if (opcode == SpvOpBitcast) {
return {ast_type, create<ast::BitcastExpression>(Source{}, ast_type->Build(builder_),
MakeOperand(inst, 0).expr)};
}
if (opcode == SpvOpShiftLeftLogical || opcode == SpvOpShiftRightLogical ||
opcode == SpvOpShiftRightArithmetic) {
auto arg0 = MakeOperand(inst, 0);
// The second operand must be unsigned. It's ok to wrap the shift amount
// since the shift is modulo the bit width of the first operand.
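// Sketch (the exact casts are chosen by the rectification helpers): an
// OpShiftRightArithmetic on u32 operands roughly becomes
// 'bitcast<u32>(bitcast<i32>(x) >> y)'.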
auto arg1 = parser_impl_.AsUnsigned(MakeOperand(inst, 1));
switch (opcode) {
case SpvOpShiftLeftLogical:
binary_op = ast::BinaryOp::kShiftLeft;
break;
case SpvOpShiftRightLogical:
arg0 = parser_impl_.AsUnsigned(arg0);
binary_op = ast::BinaryOp::kShiftRight;
break;
case SpvOpShiftRightArithmetic:
arg0 = parser_impl_.AsSigned(arg0);
binary_op = ast::BinaryOp::kShiftRight;
break;
default:
break;
}
TypedExpression result{
ast_type, create<ast::BinaryExpression>(Source{}, binary_op, arg0.expr, arg1.expr)};
return parser_impl_.RectifyForcedResultType(result, inst, arg0.type);
}
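// Unordered float comparisons have no direct WGSL operator, so they are
// emitted as the negation of the complementary ordered comparison,
// e.g. 'OpFUnordLessThan %a %b' becomes '!(a >= b)'.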
auto negated_op = NegatedFloatCompare(opcode);
if (negated_op != ast::BinaryOp::kNone) {
auto arg0 = MakeOperand(inst, 0);
auto arg1 = MakeOperand(inst, 1);
auto* binary_expr =
create<ast::BinaryExpression>(Source{}, negated_op, arg0.expr, arg1.expr);
auto* negated_expr =
create<ast::UnaryOpExpression>(Source{}, ast::UnaryOp::kNot, binary_expr);
return {ast_type, negated_expr};
}
if (opcode == SpvOpExtInst) {
if (parser_impl_.IsIgnoredExtendedInstruction(inst)) {
// Ignore it but don't error out.
return {};
}
if (!parser_impl_.IsGlslExtendedInstruction(inst)) {
Fail() << "unhandled extended instruction import with ID "
<< inst.GetSingleWordInOperand(0);
return {};
}
return EmitGlslStd450ExtInst(inst);
}
if (opcode == SpvOpCompositeConstruct) {
ast::ExpressionList operands;
for (uint32_t iarg = 0; iarg < inst.NumInOperands(); ++iarg) {
operands.emplace_back(MakeOperand(inst, iarg).expr);
}
return {ast_type,
builder_.Construct(Source{}, ast_type->Build(builder_), std::move(operands))};
}
if (opcode == SpvOpCompositeExtract) {
return MakeCompositeExtract(inst);
}
if (opcode == SpvOpVectorShuffle) {
return MakeVectorShuffle(inst);
}
if (opcode == SpvOpVectorExtractDynamic) {
return {ast_type, create<ast::IndexAccessorExpression>(Source{}, MakeOperand(inst, 0).expr,
MakeOperand(inst, 1).expr)};
}
if (opcode == SpvOpConvertSToF || opcode == SpvOpConvertUToF || opcode == SpvOpConvertFToS ||
opcode == SpvOpConvertFToU) {
return MakeNumericConversion(inst);
}
if (opcode == SpvOpUndef) {
// Replace undef with the null value.
return parser_impl_.MakeNullExpression(ast_type);
}
if (opcode == SpvOpSelect) {
return MakeSimpleSelect(inst);
}
if (opcode == SpvOpArrayLength) {
return MakeArrayLength(inst);
}
// Remaining cases are intentionally not handled here; returning {} lets
// the caller report an unhandled instruction. These include:
// builtin readonly function
// glsl.std.450 readonly function
// Instructions:
// OpSatConvertSToU // Only in Kernel (OpenCL), not in WebGPU
// OpSatConvertUToS // Only in Kernel (OpenCL), not in WebGPU
// OpUConvert // Only needed when multiple widths supported
// OpSConvert // Only needed when multiple widths supported
// OpFConvert // Only needed when multiple widths supported
// OpConvertPtrToU // Not in WebGPU
// OpConvertUToPtr // Not in WebGPU
// OpPtrCastToGeneric // Not in Vulkan
// OpGenericCastToPtr // Not in Vulkan
// OpGenericCastToPtrExplicit // Not in Vulkan
return {};
}
TypedExpression FunctionEmitter::EmitGlslStd450ExtInst(const spvtools::opt::Instruction& inst) {
const auto ext_opcode = inst.GetSingleWordInOperand(1);
if (ext_opcode == GLSLstd450Ldexp) {
// WGSL requires the second argument to be signed.
// Use a type constructor to convert it, which is the same as a bitcast.
// If the value would go from very large positive to negative, then the
// original result would have been infinity. And since WGSL
// implementations may assume that infinities are not present, then we
// don't have to worry about that case.
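// Sketch: '%r = OpExtInst %f32 %glsl Ldexp %x %e' with an unsigned %e is
// emitted roughly as 'ldexp(x, i32(e))'.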
auto e1 = MakeOperand(inst, 2);
auto e2 = ToSignedIfUnsigned(MakeOperand(inst, 3));
return {e1.type, builder_.Call(Source{}, "ldexp", ast::ExpressionList{e1.expr, e2.expr})};
}
auto* result_type = parser_impl_.ConvertType(inst.type_id());
if (result_type->IsScalar()) {
// Some GLSLstd450 builtins have scalar forms not supported by WGSL.
// Emulate them.
switch (ext_opcode) {
case GLSLstd450Normalize:
// WGSL does not have scalar form of the normalize builtin.
// The answer would be 1 anyway, so return that directly.
return {ty_.F32(), builder_.Expr(1_f)};
case GLSLstd450FaceForward: {
// If dot(Nref, Incident) < 0, the result is Normal, otherwise -Normal.
// Also: select(-normal, normal, Incident*Nref < 0)
// (The dot product of scalars is their product.)
// Use a multiply instead of comparing floating point signs. It should
// be among the fastest operations on a GPU.
auto normal = MakeOperand(inst, 2);
auto incident = MakeOperand(inst, 3);
auto nref = MakeOperand(inst, 4);
TINT_ASSERT(Reader, normal.type->Is<F32>());
TINT_ASSERT(Reader, incident.type->Is<F32>());
TINT_ASSERT(Reader, nref.type->Is<F32>());
return {ty_.F32(),
builder_.Call(
Source{}, "select",
ast::ExpressionList{create<ast::UnaryOpExpression>(
Source{}, ast::UnaryOp::kNegation, normal.expr),
normal.expr,
create<ast::BinaryExpression>(
Source{}, ast::BinaryOp::kLessThan,
builder_.Mul({}, incident.expr, nref.expr),
builder_.Expr(0_f))})};
}
case GLSLstd450Reflect: {
// Compute Incident - 2 * Normal * Normal * Incident
auto incident = MakeOperand(inst, 2);
auto normal = MakeOperand(inst, 3);
TINT_ASSERT(Reader, incident.type->Is<F32>());
TINT_ASSERT(Reader, normal.type->Is<F32>());
return {
ty_.F32(),
builder_.Sub(
incident.expr,
builder_.Mul(2_f, builder_.Mul(normal.expr,
builder_.Mul(normal.expr, incident.expr))))};
}
case GLSLstd450Refract: {
// It's a complicated expression. Compute it in two dimensions, but
// with a 0-valued y component in both the incident and normal vectors,
// then take the x component of that result.
auto incident = MakeOperand(inst, 2);
auto normal = MakeOperand(inst, 3);
auto eta = MakeOperand(inst, 4);
TINT_ASSERT(Reader, incident.type->Is<F32>());
TINT_ASSERT(Reader, normal.type->Is<F32>());
TINT_ASSERT(Reader, eta.type->Is<F32>());
if (!success()) {
return {};
}
const Type* f32 = eta.type;
return {f32, builder_.MemberAccessor(
builder_.Call(
Source{}, "refract",
ast::ExpressionList{
builder_.vec2<tint::f32>(incident.expr, 0_f),
builder_.vec2<tint::f32>(normal.expr, 0_f), eta.expr}),
"x")};
}
default:
break;
}
}
const auto name = GetGlslStd450FuncName(ext_opcode);
if (name.empty()) {
Fail() << "unhandled GLSL.std.450 instruction " << ext_opcode;
return {};
}
auto* func = create<ast::IdentifierExpression>(Source{}, builder_.Symbols().Register(name));
ast::ExpressionList operands;
const Type* first_operand_type = nullptr;
// All parameters to GLSL.std.450 extended instructions are IDs.
for (uint32_t iarg = 2; iarg < inst.NumInOperands(); ++iarg) {
TypedExpression operand = MakeOperand(inst, iarg);
if (first_operand_type == nullptr) {
first_operand_type = operand.type;
}
operands.emplace_back(operand.expr);
}
auto* call = create<ast::CallExpression>(Source{}, func, std::move(operands));
TypedExpression call_expr{result_type, call};
return parser_impl_.RectifyForcedResultType(call_expr, inst, first_operand_type);
}
ast::IdentifierExpression* FunctionEmitter::Swizzle(uint32_t i) {
if (i >= kMaxVectorLen) {
Fail() << "vector component index is larger than " << kMaxVectorLen - 1 << ": " << i;
return nullptr;
}
const char* names[] = {"x", "y", "z", "w"};
return create<ast::IdentifierExpression>(Source{}, builder_.Symbols().Register(names[i & 3]));
}
ast::IdentifierExpression* FunctionEmitter::PrefixSwizzle(uint32_t n) {
switch (n) {
case 1:
return create<ast::IdentifierExpression>(Source{}, builder_.Symbols().Register("x"));
case 2:
return create<ast::IdentifierExpression>(Source{}, builder_.Symbols().Register("xy"));
case 3:
return create<ast::IdentifierExpression>(Source{}, builder_.Symbols().Register("xyz"));
default:
break;
}
Fail() << "invalid swizzle prefix count: " << n;
return nullptr;
}
TypedExpression FunctionEmitter::MakeFMod(const spvtools::opt::Instruction& inst) {
auto x = MakeOperand(inst, 0);
auto y = MakeOperand(inst, 1);
if (!x || !y) {
return {};
}
// Emulated with: x - y * floor(x / y)
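// e.g. '%r = OpFMod %f32 %x %y' becomes '(x - (y * floor((x / y))))',
// whose result sign follows y, as OpFMod requires.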
auto* div = builder_.Div(x.expr, y.expr);
auto* floor = builder_.Call("floor", div);
auto* y_floor = builder_.Mul(y.expr, floor);
auto* res = builder_.Sub(x.expr, y_floor);
return {x.type, res};
}
TypedExpression FunctionEmitter::MakeAccessChain(const spvtools::opt::Instruction& inst) {
if (inst.NumInOperands() < 1) {
// Binary parsing will fail on this anyway.
Fail() << "invalid access chain: has no input operands";
return {};
}
const auto base_id = inst.GetSingleWordInOperand(0);
const auto base_skip = GetSkipReason(base_id);
if (base_skip != SkipReason::kDontSkip) {
// This can occur for AccessChain with no indices.
GetDefInfo(inst.result_id())->skip = base_skip;
GetDefInfo(inst.result_id())->sink_pointer_source_expr =
GetDefInfo(base_id)->sink_pointer_source_expr;
return {};
}
auto ptr_ty_id = def_use_mgr_->GetDef(base_id)->type_id();
uint32_t first_index = 1;
const auto num_in_operands = inst.NumInOperands();
bool sink_pointer = false;
TypedExpression current_expr;
// If the variable was originally gl_PerVertex, then in the AST we
// have instead emitted a gl_Position variable.
// If computing the pointer to the Position builtin, then emit the
// pointer to the generated gl_Position variable.
// If computing the pointer to the PointSize builtin, then mark the
// result as skippable due to being the point-size pointer.
// If computing the pointer to the ClipDistance or CullDistance builtins,
// then error out.
{
const auto& builtin_position_info = parser_impl_.GetBuiltInPositionInfo();
if (base_id == builtin_position_info.per_vertex_var_id) {
// We only support the Position member.
const auto* member_index_inst =
def_use_mgr_->GetDef(inst.GetSingleWordInOperand(first_index));
if (member_index_inst == nullptr) {
Fail() << "first index of access chain does not reference an instruction: "
<< inst.PrettyPrint();
return {};
}
const auto* member_index_const = constant_mgr_->GetConstantFromInst(member_index_inst);
if (member_index_const == nullptr) {
Fail() << "first index of access chain into per-vertex structure is "
"not a constant: "
<< inst.PrettyPrint();
return {};
}
const auto* member_index_const_int = member_index_const->AsIntConstant();
if (member_index_const_int == nullptr) {
Fail() << "first index of access chain into per-vertex structure is "
"not a constant integer: "
<< inst.PrettyPrint();
return {};
}
const auto member_index_value = member_index_const_int->GetZeroExtendedValue();
if (member_index_value != builtin_position_info.position_member_index) {
if (member_index_value == builtin_position_info.pointsize_member_index) {
if (auto* def_info = GetDefInfo(inst.result_id())) {
def_info->skip = SkipReason::kPointSizeBuiltinPointer;
return {};
}
} else {
// TODO(dneto): Handle ClipDistance and CullDistance
Fail() << "accessing per-vertex member " << member_index_value
<< " is not supported. Only Position is supported, and "
"PointSize is ignored";
return {};
}
}
// Skip past the member index that gets us to Position.
first_index = first_index + 1;
// Replace the gl_PerVertex reference with the gl_Position reference
ptr_ty_id = builtin_position_info.position_member_pointer_type_id;
auto name = namer_.Name(base_id);
current_expr.expr =
create<ast::IdentifierExpression>(Source{}, builder_.Symbols().Register(name));
current_expr.type = parser_impl_.ConvertType(ptr_ty_id, PtrAs::Ref);
}
}
// A SPIR-V access chain is a single instruction with multiple indices
// walking down into composites. The Tint AST represents this as
// ever-deeper nested indexing expressions. Start off with an expression
// for the base, and then bury that inside nested indexing expressions.
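// Sketch (names illustrative): '%p = OpAccessChain %ptr %var %c2 %i',
// where %c2 is constant 2 selecting a struct member and %i indexes an
// array, becomes the reference expression 'var.field2[i]'.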
if (!current_expr) {
current_expr = InferFunctionStorageClass(MakeOperand(inst, 0));
if (current_expr.type->Is<Pointer>()) {
current_expr = Dereference(current_expr);
}
}
const auto constants = constant_mgr_->GetOperandConstants(&inst);
const auto* ptr_type_inst = def_use_mgr_->GetDef(ptr_ty_id);
if (!ptr_type_inst || (ptr_type_inst->opcode() != SpvOpTypePointer)) {
Fail() << "Access chain %" << inst.result_id() << " base pointer is not of pointer type";
return {};
}
SpvStorageClass storage_class =
static_cast<SpvStorageClass>(ptr_type_inst->GetSingleWordInOperand(0));
uint32_t pointee_type_id = ptr_type_inst->GetSingleWordInOperand(1);
// Build up a nested expression for the access chain by walking down the type
// hierarchy, maintaining |pointee_type_id| as the SPIR-V ID of the type of
// the object pointed to after processing the previous indices.
for (uint32_t index = first_index; index < num_in_operands; ++index) {
const auto* index_const = constants[index] ? constants[index]->AsIntConstant() : nullptr;
const int64_t index_const_val = index_const ? index_const->GetSignExtendedValue() : 0;
const ast::Expression* next_expr = nullptr;
const auto* pointee_type_inst = def_use_mgr_->GetDef(pointee_type_id);
if (!pointee_type_inst) {
Fail() << "pointee type %" << pointee_type_id << " is invalid after following "
<< (index - first_index) << " indices: " << inst.PrettyPrint();
return {};
}
switch (pointee_type_inst->opcode()) {
case SpvOpTypeVector:
if (index_const) {
// Try generating a MemberAccessor expression
const auto num_elems = pointee_type_inst->GetSingleWordInOperand(1);
if (index_const_val < 0 || num_elems <= index_const_val) {
Fail() << "Access chain %" << inst.result_id() << " index %"
<< inst.GetSingleWordInOperand(index) << " value " << index_const_val
<< " is out of bounds for vector of " << num_elems << " elements";
return {};
}
if (uint64_t(index_const_val) >= kMaxVectorLen) {
Fail() << "internal error: swizzle index " << index_const_val
<< " is too big. Max handled index is " << kMaxVectorLen - 1;
return {};
}
next_expr = create<ast::MemberAccessorExpression>(
Source{}, current_expr.expr, Swizzle(uint32_t(index_const_val)));
} else {
// Non-constant index. Use array syntax
next_expr = create<ast::IndexAccessorExpression>(Source{}, current_expr.expr,
MakeOperand(inst, index).expr);
}
// All vector components are the same type.
pointee_type_id = pointee_type_inst->GetSingleWordInOperand(0);
// Sink pointers to vector components.
sink_pointer = true;
break;
case SpvOpTypeMatrix:
// Use array syntax.
next_expr = create<ast::IndexAccessorExpression>(Source{}, current_expr.expr,
MakeOperand(inst, index).expr);
// All matrix components are the same type.
pointee_type_id = pointee_type_inst->GetSingleWordInOperand(0);
break;
case SpvOpTypeArray:
next_expr = create<ast::IndexAccessorExpression>(Source{}, current_expr.expr,
MakeOperand(inst, index).expr);
pointee_type_id = pointee_type_inst->GetSingleWordInOperand(0);
break;
case SpvOpTypeRuntimeArray:
next_expr = create<ast::IndexAccessorExpression>(Source{}, current_expr.expr,
MakeOperand(inst, index).expr);
pointee_type_id = pointee_type_inst->GetSingleWordInOperand(0);
break;
case SpvOpTypeStruct: {
if (!index_const) {
Fail() << "Access chain %" << inst.result_id() << " index %"
<< inst.GetSingleWordInOperand(index)
<< " is a non-constant index into a structure %" << pointee_type_id;
return {};
}
const auto num_members = pointee_type_inst->NumInOperands();
if ((index_const_val < 0) || num_members <= uint64_t(index_const_val)) {
Fail() << "Access chain %" << inst.result_id() << " index value "
<< index_const_val << " is out of bounds for structure %"
<< pointee_type_id << " having " << num_members << " members";
return {};
}
auto name = namer_.GetMemberName(pointee_type_id, uint32_t(index_const_val));
auto* member_access =
create<ast::IdentifierExpression>(Source{}, builder_.Symbols().Register(name));
next_expr = create<ast::MemberAccessorExpression>(Source{}, current_expr.expr,
member_access);
pointee_type_id = pointee_type_inst->GetSingleWordInOperand(
static_cast<uint32_t>(index_const_val));
break;
}
default:
Fail() << "Access chain with unknown or invalid pointee type %" << pointee_type_id
<< ": " << pointee_type_inst->PrettyPrint();
return {};
}
const auto pointer_type_id = type_mgr_->FindPointerToType(pointee_type_id, storage_class);
auto* type = parser_impl_.ConvertType(pointer_type_id, PtrAs::Ref);
TINT_ASSERT(Reader, type && type->Is<Reference>());
current_expr = TypedExpression{type, next_expr};
}
if (sink_pointer) {
// Capture the reference so that we can sink it into the point of use.
GetDefInfo(inst.result_id())->skip = SkipReason::kSinkPointerIntoUse;
GetDefInfo(inst.result_id())->sink_pointer_source_expr = current_expr;
}
return current_expr;
}
TypedExpression FunctionEmitter::MakeCompositeExtract(const spvtools::opt::Instruction& inst) {
// This is structurally similar to creating an access chain, but
// the SPIR-V instruction has literal indices instead of IDs for indices.
auto composite_index = 0;
auto first_index_position = 1;
TypedExpression current_expr(MakeOperand(inst, composite_index));
if (!current_expr) {
return {};
}
const auto composite_id = inst.GetSingleWordInOperand(composite_index);
auto current_type_id = def_use_mgr_->GetDef(composite_id)->type_id();
return MakeCompositeValueDecomposition(inst, current_expr, current_type_id,
first_index_position);
}
TypedExpression FunctionEmitter::MakeCompositeValueDecomposition(
const spvtools::opt::Instruction& inst,
TypedExpression composite,
uint32_t composite_type_id,
int index_start) {
// This is structurally similar to creating an access chain, but
// the SPIR-V instruction has literal indices instead of IDs for indices.
// A SPIR-V composite extract is a single instruction with multiple
// literal indices walking down into composites.
// A SPIR-V composite insert is similar but also tells you what component
// to inject. This function is responsible for the walking-into part
// of composite-insert.
//
// The Tint AST represents this as ever-deeper nested indexing expressions.
// Start off with an expression for the composite, and then bury that inside
// nested indexing expressions.
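// For example (illustrative): 'OpCompositeExtract %f32 %mat 2 1' walks
// into column 2 of a matrix and then component 1, yielding 'mat[2u].y'.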
auto current_expr = composite;
auto current_type_id = composite_type_id;
auto make_index = [this](uint32_t literal) {
return create<ast::IntLiteralExpression>(Source{}, literal,
ast::IntLiteralExpression::Suffix::kU);
};
// Build up a nested expression for the decomposition by walking down the type
// hierarchy, maintaining |current_type_id| as the SPIR-V ID of the type of
// the object pointed to after processing the previous indices.
const auto num_in_operands = inst.NumInOperands();
for (uint32_t index = index_start; index < num_in_operands; ++index) {
const uint32_t index_val = inst.GetSingleWordInOperand(index);
const auto* current_type_inst = def_use_mgr_->GetDef(current_type_id);
if (!current_type_inst) {
Fail() << "composite type %" << current_type_id << " is invalid after following "
<< (index - index_start) << " indices: " << inst.PrettyPrint();
return {};
}
const char* operation_name = nullptr;
switch (inst.opcode()) {
case SpvOpCompositeExtract:
operation_name = "OpCompositeExtract";
break;
case SpvOpCompositeInsert:
operation_name = "OpCompositeInsert";
break;
default:
Fail() << "internal error: unhandled " << inst.PrettyPrint();
return {};
}
const ast::Expression* next_expr = nullptr;
switch (current_type_inst->opcode()) {
case SpvOpTypeVector: {
// Try generating a MemberAccessor expression. That results in something
// like "foo.z", which is more idiomatic than "foo[2]".
const auto num_elems = current_type_inst->GetSingleWordInOperand(1);
if (num_elems <= index_val) {
Fail() << operation_name << " %" << inst.result_id() << " index value "
<< index_val << " is out of bounds for vector of " << num_elems
<< " elements";
return {};
}
if (index_val >= kMaxVectorLen) {
Fail() << "internal error: swizzle index " << index_val
<< " is too big. Max handled index is " << kMaxVectorLen - 1;
return {};
}
next_expr = create<ast::MemberAccessorExpression>(Source{}, current_expr.expr,
Swizzle(index_val));
// All vector components are the same type.
current_type_id = current_type_inst->GetSingleWordInOperand(0);
break;
}
case SpvOpTypeMatrix: {
// Check bounds
const auto num_elems = current_type_inst->GetSingleWordInOperand(1);
if (num_elems <= index_val) {
Fail() << operation_name << " %" << inst.result_id() << " index value "
<< index_val << " is out of bounds for matrix of " << num_elems
<< " elements";
return {};
}
if (index_val >= kMaxVectorLen) {
Fail() << "internal error: swizzle index " << index_val
<< " is too big. Max handled index is " << kMaxVectorLen - 1;
return {};
}
// Use array syntax.
next_expr = create<ast::IndexAccessorExpression>(Source{}, current_expr.expr,
make_index(index_val));
// All matrix components are the same type.
current_type_id = current_type_inst->GetSingleWordInOperand(0);
break;
}
case SpvOpTypeArray:
// The array size could be a spec constant, and so it's not always
// statically checkable. Instead, rely on a runtime index clamp
// or runtime check to keep this safe.
next_expr = create<ast::IndexAccessorExpression>(Source{}, current_expr.expr,
make_index(index_val));
current_type_id = current_type_inst->GetSingleWordInOperand(0);
break;
case SpvOpTypeRuntimeArray:
Fail() << "can't do " << operation_name
<< " on a runtime array: " << inst.PrettyPrint();
return {};
case SpvOpTypeStruct: {
const auto num_members = current_type_inst->NumInOperands();
if (num_members <= index_val) {
Fail() << operation_name << " %" << inst.result_id() << " index value "
<< index_val << " is out of bounds for structure %" << current_type_id
<< " having " << num_members << " members";
return {};
}
auto name = namer_.GetMemberName(current_type_id, uint32_t(index_val));
auto* member_access =
create<ast::IdentifierExpression>(Source{}, builder_.Symbols().Register(name));
next_expr = create<ast::MemberAccessorExpression>(Source{}, current_expr.expr,
member_access);
current_type_id = current_type_inst->GetSingleWordInOperand(index_val);
break;
}
default:
Fail() << operation_name << " with bad type %" << current_type_id << ": "
<< current_type_inst->PrettyPrint();
return {};
}
current_expr = TypedExpression{parser_impl_.ConvertType(current_type_id), next_expr};
}
return current_expr;
}
const ast::Expression* FunctionEmitter::MakeTrue(const Source& source) const {
return create<ast::BoolLiteralExpression>(source, true);
}
const ast::Expression* FunctionEmitter::MakeFalse(const Source& source) const {
return create<ast::BoolLiteralExpression>(source, false);
}
TypedExpression FunctionEmitter::MakeVectorShuffle(const spvtools::opt::Instruction& inst) {
const auto vec0_id = inst.GetSingleWordInOperand(0);
const auto vec1_id = inst.GetSingleWordInOperand(1);
const spvtools::opt::Instruction& vec0 = *(def_use_mgr_->GetDef(vec0_id));
const spvtools::opt::Instruction& vec1 = *(def_use_mgr_->GetDef(vec1_id));
const auto vec0_len = type_mgr_->GetType(vec0.type_id())->AsVector()->element_count();
const auto vec1_len = type_mgr_->GetType(vec1.type_id())->AsVector()->element_count();
// Idiomatic vector accessors.
// Generate an ast::TypeConstructor expression.
// Assume the literal indices are valid, and there is a valid number of them.
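// Sketch: 'OpVectorShuffle %v2f %a %b 0 3' over two vec2 operands becomes
// 'vec2<f32>(a.x, b.y)'; the sentinel index 0xFFFFFFFF becomes a
// zero-valued component.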
auto source = GetSourceForInst(inst);
const Vector* result_type = As<Vector>(parser_impl_.ConvertType(inst.type_id()));
ast::ExpressionList values;
for (uint32_t i = 2; i < inst.NumInOperands(); ++i) {
const auto index = inst.GetSingleWordInOperand(i);
if (index < vec0_len) {
auto expr = MakeExpression(vec0_id);
if (!expr) {
return {};
}
values.emplace_back(
create<ast::MemberAccessorExpression>(source, expr.expr, Swizzle(index)));
} else if (index < vec0_len + vec1_len) {
const auto sub_index = index - vec0_len;
TINT_ASSERT(Reader, sub_index < kMaxVectorLen);
auto expr = MakeExpression(vec1_id);
if (!expr) {
return {};
}
values.emplace_back(
create<ast::MemberAccessorExpression>(source, expr.expr, Swizzle(sub_index)));
} else if (index == 0xFFFFFFFF) {
// By rule, this maps to OpUndef. Instead, make it zero.
values.emplace_back(parser_impl_.MakeNullValue(result_type->type));
} else {
Fail() << "invalid vectorshuffle ID %" << inst.result_id()
<< ": index too large: " << index;
return {};
}
}
return {result_type, builder_.Construct(source, result_type->Build(builder_), values)};
}
bool FunctionEmitter::RegisterSpecialBuiltInVariables() {
size_t index = def_info_.size();
for (auto& special_var : parser_impl_.special_builtins()) {
const auto id = special_var.first;
const auto builtin = special_var.second;
const auto* var = def_use_mgr_->GetDef(id);
def_info_[id] = std::make_unique<DefInfo>(*var, 0, index);
++index;
auto& def = def_info_[id];
switch (builtin) {
case SpvBuiltInPointSize:
def->skip = SkipReason::kPointSizeBuiltinPointer;
break;
case SpvBuiltInSampleMask: {
// Distinguish between input and output variable.
const auto storage_class =
static_cast<SpvStorageClass>(var->GetSingleWordInOperand(0));
if (storage_class == SpvStorageClassInput) {
sample_mask_in_id = id;
def->skip = SkipReason::kSampleMaskInBuiltinPointer;
} else {
sample_mask_out_id = id;
def->skip = SkipReason::kSampleMaskOutBuiltinPointer;
}
break;
}
case SpvBuiltInSampleId:
case SpvBuiltInInstanceIndex:
case SpvBuiltInVertexIndex:
case SpvBuiltInLocalInvocationIndex:
case SpvBuiltInLocalInvocationId:
case SpvBuiltInGlobalInvocationId:
case SpvBuiltInWorkgroupId:
case SpvBuiltInNumWorkgroups:
break;
default:
return Fail() << "unrecognized special builtin: " << int(builtin);
}
}
return true;
}
bool FunctionEmitter::RegisterLocallyDefinedValues() {
// Create a DefInfo for each value definition in this function.
size_t index = def_info_.size();
for (auto block_id : block_order_) {
const auto* block_info = GetBlockInfo(block_id);
const auto block_pos = block_info->pos;
for (const auto& inst : *(block_info->basic_block)) {
const auto result_id = inst.result_id();
if ((result_id == 0) || inst.opcode() == SpvOpLabel) {
continue;
}
def_info_[result_id] = std::make_unique<DefInfo>(inst, block_pos, index);
++index;
auto& info = def_info_[result_id];
// Determine storage class for pointer values. Do this in order because
// we might rely on the storage class for a previously-visited definition.
// Logical pointers can't be transmitted through OpPhi, so remaining
// pointer definitions are SSA values, and their definitions must be
// visited before their uses.
const auto* type = type_mgr_->GetType(inst.type_id());
if (type) {
if (type->AsPointer()) {
if (auto* ast_type = parser_impl_.ConvertType(inst.type_id())) {
if (auto* ptr = ast_type->As<Pointer>()) {
info->storage_class = ptr->storage_class;
}
}
switch (inst.opcode()) {
case SpvOpUndef:
return Fail() << "undef pointer is not valid: " << inst.PrettyPrint();
case SpvOpVariable:
// Keep the default decision based on the result type.
break;
case SpvOpAccessChain:
case SpvOpInBoundsAccessChain:
case SpvOpCopyObject:
// Inherit from the first operand. We need this so we can pick up
// a remapped storage buffer.
info->storage_class =
GetStorageClassForPointerValue(inst.GetSingleWordInOperand(0));
break;
default:
return Fail() << "pointer defined in function from unknown opcode: "
<< inst.PrettyPrint();
}
}
auto* unwrapped = type;
while (auto* ptr = unwrapped->AsPointer()) {
unwrapped = ptr->pointee_type();
}
if (unwrapped->AsSampler() || unwrapped->AsImage() || unwrapped->AsSampledImage()) {
// Defer code generation until the instruction that actually acts on
// the image.
info->skip = SkipReason::kOpaqueObject;
}
}
}
}
return true;
}
ast::StorageClass FunctionEmitter::GetStorageClassForPointerValue(uint32_t id) {
auto where = def_info_.find(id);
if (where != def_info_.end()) {
auto candidate = where->second.get()->storage_class;
if (candidate != ast::StorageClass::kInvalid) {
return candidate;
}
}
const auto type_id = def_use_mgr_->GetDef(id)->type_id();
if (type_id) {
auto* ast_type = parser_impl_.ConvertType(type_id);
if (auto* ptr = As<Pointer>(ast_type)) {
return ptr->storage_class;
}
}
return ast::StorageClass::kInvalid;
}
const Type* FunctionEmitter::RemapStorageClass(const Type* type, uint32_t result_id) {
if (auto* ast_ptr_type = As<Pointer>(type)) {
// Remap an old-style storage buffer pointer to a new-style storage
// buffer pointer.
const auto sc = GetStorageClassForPointerValue(result_id);
if (ast_ptr_type->storage_class != sc) {
return ty_.Pointer(ast_ptr_type->type, sc);
}
}
return type;
}
void FunctionEmitter::FindValuesNeedingNamedOrHoistedDefinition() {
// Mark vector operands of OpVectorShuffle as needing a named definition,
// but only if they are defined in this function as well.
auto require_named_const_def = [&](const spvtools::opt::Instruction& inst,
int in_operand_index) {
const auto id = inst.GetSingleWordInOperand(in_operand_index);
auto* const operand_def = GetDefInfo(id);
if (operand_def) {
operand_def->requires_named_const_def = true;
}
};
for (auto& id_def_info_pair : def_info_) {
const auto& inst = id_def_info_pair.second->inst;
const auto opcode = inst.opcode();
if ((opcode == SpvOpVectorShuffle) || (opcode == SpvOpOuterProduct)) {
// We might access the vector operands multiple times. Make sure they
// are evaluated only once.
require_named_const_def(inst, 0);
require_named_const_def(inst, 1);
}
if (parser_impl_.IsGlslExtendedInstruction(inst)) {
// Some emulations of GLSLstd450 instructions evaluate certain operands
// multiple times. Ensure their expressions are evaluated only once.
switch (inst.GetSingleWordInOperand(1)) {
case GLSLstd450FaceForward:
// The "normal" operand expression is used twice in code generation.
require_named_const_def(inst, 2);
break;
case GLSLstd450Reflect:
require_named_const_def(inst, 2); // Incident
require_named_const_def(inst, 3); // Normal
break;
default:
break;
}
}
}
// Scan uses of locally defined IDs, in function block order.
for (auto block_id : block_order_) {
const auto* block_info = GetBlockInfo(block_id);
const auto block_pos = block_info->pos;
for (const auto& inst : *(block_info->basic_block)) {
// Update bookkeeping for locally-defined IDs used by this instruction.
inst.ForEachInId([this, block_pos, block_info](const uint32_t* id_ptr) {
auto* def_info = GetDefInfo(*id_ptr);
if (def_info) {
// Update usage count.
def_info->num_uses++;
// Update usage span.
def_info->last_use_pos = std::max(def_info->last_use_pos, block_pos);
// Determine whether this ID is defined in a different construct
// from this use.
const auto defining_block = block_order_[def_info->block_pos];
const auto* def_in_construct = GetBlockInfo(defining_block)->construct;
if (def_in_construct != block_info->construct) {
def_info->used_in_another_construct = true;
}
}
});
if (inst.opcode() == SpvOpPhi) {
// Declare a name for the variable used to carry values to a phi.
const auto phi_id = inst.result_id();
auto* phi_def_info = GetDefInfo(phi_id);
phi_def_info->phi_var = namer_.MakeDerivedName(namer_.Name(phi_id) + "_phi");
// Track all the places where we need to mention the variable,
// so we can place its declaration. First, record the location of
// the read from the variable.
uint32_t first_pos = block_pos;
uint32_t last_pos = block_pos;
// Record the assignments that will propagate values from predecessor
// blocks.
for (uint32_t i = 0; i + 1 < inst.NumInOperands(); i += 2) {
const uint32_t value_id = inst.GetSingleWordInOperand(i);
const uint32_t pred_block_id = inst.GetSingleWordInOperand(i + 1);
auto* pred_block_info = GetBlockInfo(pred_block_id);
// The predecessor might not be in the block order at all, so we
// need this guard.
if (IsInBlockOrder(pred_block_info)) {
// Record the assignment that needs to occur at the end
// of the predecessor block.
pred_block_info->phi_assignments.push_back({phi_id, value_id});
first_pos = std::min(first_pos, pred_block_info->pos);
last_pos = std::max(last_pos, pred_block_info->pos);
}
}
// Schedule the declaration of the state variable.
const auto* enclosing_construct = GetEnclosingScope(first_pos, last_pos);
GetBlockInfo(enclosing_construct->begin_id)
->phis_needing_state_vars.push_back(phi_id);
}
}
}
// For an ID defined in this function, determine if its evaluation and
// potential declaration needs special handling:
// - Compensate for the fact that dominance does not map directly to scope.
// A definition could dominate its use, but a named definition in WGSL
// at the location of the definition could go out of scope by the time
// you reach the use. In that case, we hoist the definition to a basic
// block at the smallest scope enclosing both the definition and all
// its uses.
// - If value is used in a different construct than its definition, then it
// needs a named constant definition. Otherwise we might sink an
// expensive computation into control flow, and hence change performance.
for (auto& id_def_info_pair : def_info_) {
const auto def_id = id_def_info_pair.first;
auto* def_info = id_def_info_pair.second.get();
if (def_info->num_uses == 0) {
// There is no need to adjust the location of the declaration.
continue;
}
// The first use must be the at the SSA definition, because block order
// respects dominance.
const auto first_pos = def_info->block_pos;
const auto last_use_pos = def_info->last_use_pos;
const auto* def_in_construct = GetBlockInfo(block_order_[first_pos])->construct;
// A definition in the first block of a kIfSelection or kSwitchSelection
// occurs before the branch, and so that definition should count as
// having been defined at the scope of the parent construct.
if (first_pos == def_in_construct->begin_pos) {
if ((def_in_construct->kind == Construct::kIfSelection) ||
(def_in_construct->kind == Construct::kSwitchSelection)) {
def_in_construct = def_in_construct->parent;
}
}
bool should_hoist = false;
if (!def_in_construct->ContainsPos(last_use_pos)) {
// To satisfy scoping, we have to hoist the definition out to an enclosing
// construct.
should_hoist = true;
} else {
// Avoid moving combinatorial values across constructs. This is a
// simple heuristic to avoid changing the cost of an operation
// by moving it into or out of a loop, for example.
if ((def_info->storage_class == ast::StorageClass::kInvalid) &&
def_info->used_in_another_construct) {
should_hoist = true;
}
}
if (should_hoist) {
const auto* enclosing_construct = GetEnclosingScope(first_pos, last_use_pos);
if (enclosing_construct == def_in_construct) {
// We can use a plain 'const' definition.
def_info->requires_named_const_def = true;
} else {
// We need to make a hoisted variable definition.
// TODO(dneto): Handle non-storable types, particularly pointers.
def_info->requires_hoisted_def = true;
auto* hoist_to_block = GetBlockInfo(enclosing_construct->begin_id);
hoist_to_block->hoisted_ids.push_back(def_id);
}
}
}
}
const Construct* FunctionEmitter::GetEnclosingScope(uint32_t first_pos, uint32_t last_pos) const {
const auto* enclosing_construct = GetBlockInfo(block_order_[first_pos])->construct;
TINT_ASSERT(Reader, enclosing_construct != nullptr);
// Constructs are strictly nesting, so follow parent pointers
while (enclosing_construct && !enclosing_construct->ScopeContainsPos(last_pos)) {
// The scope of a continue construct is enclosed in its associated loop
// construct, but they are siblings in our construct tree.
const auto* sibling_loop = SiblingLoopConstruct(enclosing_construct);
// Go to the sibling loop if it exists, otherwise walk up to the parent.
enclosing_construct = sibling_loop ? sibling_loop : enclosing_construct->parent;
}
// At worst, we go all the way out to the function construct.
TINT_ASSERT(Reader, enclosing_construct != nullptr);
return enclosing_construct;
}
TypedExpression FunctionEmitter::MakeNumericConversion(const spvtools::opt::Instruction& inst) {
const auto opcode = inst.opcode();
auto* requested_type = parser_impl_.ConvertType(inst.type_id());
auto arg_expr = MakeOperand(inst, 0);
if (!arg_expr) {
return {};
}
arg_expr.type = arg_expr.type->UnwrapRef();
const Type* expr_type = nullptr;
if ((opcode == SpvOpConvertSToF) || (opcode == SpvOpConvertUToF)) {
if (arg_expr.type->IsIntegerScalarOrVector()) {
expr_type = requested_type;
} else {
Fail() << "operand for conversion to floating point must be integral "
"scalar or vector: "
<< inst.PrettyPrint();
}
} else if (inst.opcode() == SpvOpConvertFToU) {
if (arg_expr.type->IsFloatScalarOrVector()) {
expr_type = parser_impl_.GetUnsignedIntMatchingShape(arg_expr.type);
} else {
Fail() << "operand for conversion to unsigned integer must be floating "
"point scalar or vector: "
<< inst.PrettyPrint();
}
} else if (inst.opcode() == SpvOpConvertFToS) {
if (arg_expr.type->IsFloatScalarOrVector()) {
expr_type = parser_impl_.GetSignedIntMatchingShape(arg_expr.type);
} else {
Fail() << "operand for conversion to signed integer must be floating "
"point scalar or vector: "
<< inst.PrettyPrint();
}
}
if (expr_type == nullptr) {
// The diagnostic has already been emitted.
return {};
}
ast::ExpressionList params;
params.push_back(arg_expr.expr);
TypedExpression result{
expr_type,
builder_.Construct(GetSourceForInst(inst), expr_type->Build(builder_), std::move(params))};
if (requested_type == expr_type) {
return result;
}
return {requested_type,
create<ast::BitcastExpression>(GetSourceForInst(inst), requested_type->Build(builder_),
result.expr)};
}
bool FunctionEmitter::EmitFunctionCall(const spvtools::opt::Instruction& inst) {
// We ignore function attributes such as Inline, DontInline, Pure, Const.
auto name = namer_.Name(inst.GetSingleWordInOperand(0));
auto* function = create<ast::IdentifierExpression>(Source{}, builder_.Symbols().Register(name));
ast::ExpressionList args;
for (uint32_t iarg = 1; iarg < inst.NumInOperands(); ++iarg) {
auto expr = MakeOperand(inst, iarg);
if (!expr) {
return false;
}
// Functions cannot use references as parameters, so we need to pass by
// pointer if the operand is of pointer type.
expr = AddressOfIfNeeded(expr, def_use_mgr_->GetDef(inst.GetSingleWordInOperand(iarg)));
args.emplace_back(expr.expr);
}
if (failed()) {
return false;
}
auto* call_expr = create<ast::CallExpression>(Source{}, function, std::move(args));
auto* result_type = parser_impl_.ConvertType(inst.type_id());
if (!result_type) {
return Fail() << "internal error: no mapped type result of call: " << inst.PrettyPrint();
}
if (result_type->Is<Void>()) {
return nullptr != AddStatement(create<ast::CallStatement>(Source{}, call_expr));
}
return EmitConstDefOrWriteToHoistedVar(inst, {result_type, call_expr});
}
bool FunctionEmitter::EmitControlBarrier(const spvtools::opt::Instruction& inst) {
uint32_t operands[3];
for (int i = 0; i < 3; i++) {
auto id = inst.GetSingleWordInOperand(i);
if (auto* constant = constant_mgr_->FindDeclaredConstant(id)) {
operands[i] = constant->GetU32();
} else {
return Fail() << "invalid or missing operands for control barrier";
}
}
uint32_t execution = operands[0];
uint32_t memory = operands[1];
uint32_t semantics = operands[2];
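// Sketch of the expected encoding: Workgroup execution scope, Workgroup
// memory scope, and AcquireRelease|WorkgroupMemory semantics map to
// 'workgroupBarrier()'; Device memory scope with UniformMemory semantics
// maps to 'storageBarrier()'.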
if (execution != SpvScopeWorkgroup) {
return Fail() << "unsupported control barrier execution scope: "
<< "expected Workgroup (2), got: " << execution;
}
if (semantics & SpvMemorySemanticsAcquireReleaseMask) {
semantics &= ~SpvMemorySemanticsAcquireReleaseMask;
} else {
return Fail() << "control barrier semantics requires acquire and release";
}
if (semantics & SpvMemorySemanticsWorkgroupMemoryMask) {
if (memory != SpvScopeWorkgroup) {
return Fail() << "workgroupBarrier requires workgroup memory scope";
}
AddStatement(create<ast::CallStatement>(builder_.Call("workgroupBarrier")));
semantics &= ~SpvMemorySemanticsWorkgroupMemoryMask;
}
if (semantics & SpvMemorySemanticsUniformMemoryMask) {
if (memory != SpvScopeDevice) {
return Fail() << "storageBarrier requires device memory scope";
}
AddStatement(create<ast::CallStatement>(builder_.Call("storageBarrier")));
semantics &= ~SpvMemorySemanticsUniformMemoryMask;
}
if (semantics) {
return Fail() << "unsupported control barrier semantics: " << semantics;
}
return true;
}
TypedExpression FunctionEmitter::MakeBuiltinCall(const spvtools::opt::Instruction& inst) {
const auto builtin = GetBuiltin(inst.opcode());
auto* name = sem::str(builtin);
auto* ident = create<ast::IdentifierExpression>(Source{}, builder_.Symbols().Register(name));
ast::ExpressionList params;
const Type* first_operand_type = nullptr;
for (uint32_t iarg = 0; iarg < inst.NumInOperands(); ++iarg) {
TypedExpression operand = MakeOperand(inst, iarg);
if (first_operand_type == nullptr) {
first_operand_type = operand.type;
}
params.emplace_back(operand.expr);
}
auto* call_expr = create<ast::CallExpression>(Source{}, ident, std::move(params));
auto* result_type = parser_impl_.ConvertType(inst.type_id());
if (!result_type) {
Fail() << "internal error: no mapped type result of call: " << inst.PrettyPrint();
return {};
}
TypedExpression call{result_type, call_expr};
return parser_impl_.RectifyForcedResultType(call, inst, first_operand_type);
}
TypedExpression FunctionEmitter::MakeSimpleSelect(const spvtools::opt::Instruction& inst) {
auto condition = MakeOperand(inst, 0);
auto true_value = MakeOperand(inst, 1);
auto false_value = MakeOperand(inst, 2);
// SPIR-V validation requires:
// - the condition to be bool or bool vector, so we don't check it here.
// - true_value, false_value, and result type to match.
// - you can't select over pointers or pointer vectors, unless you also have
// a VariablePointers* capability, which is not allowed by WebGPU.
auto* op_ty = true_value.type;
if (op_ty->Is<Vector>() || op_ty->IsFloatScalar() || op_ty->IsIntegerScalar() ||
op_ty->Is<Bool>()) {
ast::ExpressionList params;
params.push_back(false_value.expr);
params.push_back(true_value.expr);
// The condition goes last.
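// i.e. '%r = OpSelect %ty %cond %t %f' becomes 'select(f, t, cond)',
// since WGSL's select() takes the false-branch operand first.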
params.push_back(condition.expr);
return {op_ty,
create<ast::CallExpression>(Source{},
create<ast::IdentifierExpression>(
Source{}, builder_.Symbols().Register("select")),
std::move(params))};
}
return {};
}
Source FunctionEmitter::GetSourceForInst(const spvtools::opt::Instruction& inst) const {
return parser_impl_.GetSourceForInst(&inst);
}
const spvtools::opt::Instruction* FunctionEmitter::GetImage(
const spvtools::opt::Instruction& inst) {
if (inst.NumInOperands() == 0) {
Fail() << "not an image access instruction: " << inst.PrettyPrint();
return nullptr;
}
// The image or sampled image operand is always the first operand.
const auto image_or_sampled_image_operand_id = inst.GetSingleWordInOperand(0);
const auto* image =
parser_impl_.GetMemoryObjectDeclarationForHandle(image_or_sampled_image_operand_id, true);
if (!image) {
Fail() << "internal error: couldn't find image for " << inst.PrettyPrint();
return nullptr;
}
return image;
}
const Texture* FunctionEmitter::GetImageType(const spvtools::opt::Instruction& image) {
const Pointer* ptr_type = parser_impl_.GetTypeForHandleVar(image);
if (!parser_impl_.success()) {
Fail();
return {};
}
if (!ptr_type) {
Fail() << "invalid texture type for " << image.PrettyPrint();
return {};
}
auto* result = ptr_type->type->UnwrapAll()->As<Texture>();
if (!result) {
Fail() << "invalid texture type for " << image.PrettyPrint();
return {};
}
return result;
}
const ast::Expression* FunctionEmitter::GetImageExpression(const spvtools::opt::Instruction& inst) {
auto* image = GetImage(inst);
if (!image) {
return nullptr;
}
auto name = namer_.Name(image->result_id());
return create<ast::IdentifierExpression>(GetSourceForInst(inst),
builder_.Symbols().Register(name));
}
const ast::Expression* FunctionEmitter::GetSamplerExpression(
const spvtools::opt::Instruction& inst) {
// The sampled image operand is always the first operand.
const auto image_or_sampled_image_operand_id = inst.GetSingleWordInOperand(0);
const auto* image =
parser_impl_.GetMemoryObjectDeclarationForHandle(image_or_sampled_image_operand_id, false);
if (!image) {
Fail() << "internal error: couldn't find sampler for " << inst.PrettyPrint();
return nullptr;
}
auto name = namer_.Name(image->result_id());
return create<ast::IdentifierExpression>(GetSourceForInst(inst),
builder_.Symbols().Register(name));
}
bool FunctionEmitter::EmitImageAccess(const spvtools::opt::Instruction& inst) {
ast::ExpressionList args;
const auto opcode = inst.opcode();
// Form the texture operand.
const spvtools::opt::Instruction* image = GetImage(inst);
if (!image) {
return false;
}
args.push_back(GetImageExpression(inst));
// Form the sampler operand, if needed.
if (IsSampledImageAccess(opcode)) {
// Form the sampler operand.
if (auto* sampler = GetSamplerExpression(inst)) {
args.push_back(sampler);
} else {
return false;
}
}
// Find the texture type.
const Pointer* texture_ptr_type = parser_impl_.GetTypeForHandleVar(*image);
if (!texture_ptr_type) {
return Fail();
}
const Texture* texture_type = texture_ptr_type->type->UnwrapAll()->As<Texture>();
if (!texture_type) {
return Fail();
}
// This is the SPIR-V operand index. We're done with the first operand.
uint32_t arg_index = 1;
// Push the coordinates operands.
auto coords = MakeCoordinateOperandsForImageAccess(inst);
if (coords.empty()) {
return false;
}
args.insert(args.end(), coords.begin(), coords.end());
// Skip the coordinates operand.
arg_index++;
const auto num_args = inst.NumInOperands();
// Consumes the depth-reference argument, pushing it onto the end of
// the parameter list. Issues a diagnostic and returns false on error.
auto consume_dref = [&]() -> bool {
if (arg_index < num_args) {
args.push_back(MakeOperand(inst, arg_index).expr);
arg_index++;
} else {
return Fail() << "image depth-compare instruction is missing a Dref operand: "
<< inst.PrettyPrint();
}
return true;
};
std::string builtin_name;
bool use_level_of_detail_suffix = true;
bool is_dref_sample = false;
bool is_gather_or_dref_gather = false;
bool is_non_dref_sample = false;
switch (opcode) {
case SpvOpImageSampleImplicitLod:
case SpvOpImageSampleExplicitLod:
case SpvOpImageSampleProjImplicitLod:
case SpvOpImageSampleProjExplicitLod:
is_non_dref_sample = true;
builtin_name = "textureSample";
break;
case SpvOpImageSampleDrefImplicitLod:
case SpvOpImageSampleDrefExplicitLod:
case SpvOpImageSampleProjDrefImplicitLod:
case SpvOpImageSampleProjDrefExplicitLod:
is_dref_sample = true;
builtin_name = "textureSampleCompare";
if (!consume_dref()) {
return false;
}
break;
case SpvOpImageGather:
is_gather_or_dref_gather = true;
builtin_name = "textureGather";
if (!texture_type->Is<DepthTexture>()) {
// The explicit component is the *first* argument in WGSL.
args.insert(args.begin(), ToI32(MakeOperand(inst, arg_index)).expr);
}
// Skip over the component operand, even for depth textures.
arg_index++;
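            // For example (names here are illustrative, not from the source):
            //   %r = OpImageGather %v4float %sampled_image %coords %int_1
            // becomes, for a non-depth texture:
            //   textureGather(1, t, s, coords)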
break;
case SpvOpImageDrefGather:
is_gather_or_dref_gather = true;
builtin_name = "textureGatherCompare";
if (!consume_dref()) {
return false;
}
break;
case SpvOpImageFetch:
case SpvOpImageRead:
// Read a single texel from a sampled or storage image.
builtin_name = "textureLoad";
use_level_of_detail_suffix = false;
break;
case SpvOpImageWrite:
builtin_name = "textureStore";
use_level_of_detail_suffix = false;
if (arg_index < num_args) {
auto texel = MakeOperand(inst, arg_index);
auto* converted_texel = ConvertTexelForStorage(inst, texel, texture_type);
if (!converted_texel) {
return false;
}
args.push_back(converted_texel);
arg_index++;
} else {
return Fail() << "image write is missing a Texel operand: " << inst.PrettyPrint();
}
break;
default:
return Fail() << "internal error: unrecognized image access: " << inst.PrettyPrint();
}
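    // For example (names illustrative): by this point an instruction such as
    //   %result = OpImageSampleImplicitLod %v4float %sampled_image %coords
    // has chosen the builtin name "textureSample", and the argument list so
    // far is (t, s, coords).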
    // Loop over the image operands, looking for extra operands to the
    // builtin. Except we unroll the loop, handling one mask bit at a time.
uint32_t image_operands_mask = 0;
if (arg_index < num_args) {
image_operands_mask = inst.GetSingleWordInOperand(arg_index);
arg_index++;
}
if (arg_index < num_args && (image_operands_mask & SpvImageOperandsBiasMask)) {
if (is_dref_sample) {
return Fail() << "WGSL does not support depth-reference sampling with "
"level-of-detail bias: "
<< inst.PrettyPrint();
}
if (is_gather_or_dref_gather) {
return Fail() << "WGSL does not support image gather with "
"level-of-detail bias: "
<< inst.PrettyPrint();
}
builtin_name += "Bias";
args.push_back(MakeOperand(inst, arg_index).expr);
image_operands_mask ^= SpvImageOperandsBiasMask;
arg_index++;
}
if (arg_index < num_args && (image_operands_mask & SpvImageOperandsLodMask)) {
if (use_level_of_detail_suffix) {
builtin_name += "Level";
}
if (is_dref_sample || is_gather_or_dref_gather) {
// Metal only supports Lod = 0 for comparison sampling without
// derivatives.
// Vulkan SPIR-V does not allow Lod with OpImageGather or
// OpImageDrefGather.
if (!IsFloatZero(inst.GetSingleWordInOperand(arg_index))) {
return Fail() << "WGSL comparison sampling without derivatives "
"requires level-of-detail 0.0"
<< inst.PrettyPrint();
}
// Don't generate the Lod argument.
} else {
// Generate the Lod argument.
TypedExpression lod = MakeOperand(inst, arg_index);
// When sampling from a depth texture, the Lod operand must be an I32.
if (texture_type->Is<DepthTexture>()) {
// Convert it to a signed integer type.
lod = ToI32(lod);
}
args.push_back(lod.expr);
}
image_operands_mask ^= SpvImageOperandsLodMask;
arg_index++;
} else if ((opcode == SpvOpImageFetch || opcode == SpvOpImageRead) &&
!texture_type->IsAnyOf<DepthMultisampledTexture, MultisampledTexture>()) {
// textureLoad requires an explicit level-of-detail parameter for
// non-multisampled texture types.
args.push_back(parser_impl_.MakeNullValue(ty_.I32()));
}
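    // For example (illustrative): an OpImageFetch from a non-multisampled 2D
    // texture with no Lod image operand becomes:
    //   textureLoad(t, coords, 0)
    // where the trailing 0 is the synthesized level-of-detail argument.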
if (arg_index + 1 < num_args && (image_operands_mask & SpvImageOperandsGradMask)) {
if (is_dref_sample) {
return Fail() << "WGSL does not support depth-reference sampling with "
"explicit gradient: "
<< inst.PrettyPrint();
}
if (is_gather_or_dref_gather) {
return Fail() << "WGSL does not support image gather with "
"explicit gradient: "
<< inst.PrettyPrint();
}
builtin_name += "Grad";
args.push_back(MakeOperand(inst, arg_index).expr);
args.push_back(MakeOperand(inst, arg_index + 1).expr);
image_operands_mask ^= SpvImageOperandsGradMask;
arg_index += 2;
}
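    // For example (illustrative): the optional image operands accumulate into
    // the builtin name and argument list, so a sample with an explicit
    // gradient becomes:
    //   textureSampleGrad(t, s, coords, ddx, ddy)
    // and a biased sample becomes textureSampleBias(t, s, coords, bias).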
if (arg_index < num_args && (image_operands_mask & SpvImageOperandsConstOffsetMask)) {
if (!IsImageSamplingOrGatherOrDrefGather(opcode)) {
return Fail() << "ConstOffset is only permitted for sampling, gather, or "
"depth-reference gather operations: "
<< inst.PrettyPrint();
}
switch (texture_type->dims) {
case ast::TextureDimension::k2d:
case ast::TextureDimension::k2dArray:
case ast::TextureDimension::k3d:
break;
default:
return Fail() << "ConstOffset is only permitted for 2D, 2D Arrayed, "
"and 3D textures: "
<< inst.PrettyPrint();
}
args.push_back(ToSignedIfUnsigned(MakeOperand(inst, arg_index)).expr);
image_operands_mask ^= SpvImageOperandsConstOffsetMask;
arg_index++;
}
if (arg_index < num_args && (image_operands_mask & SpvImageOperandsSampleMask)) {
// TODO(dneto): only permitted with ImageFetch
args.push_back(ToI32(MakeOperand(inst, arg_index)).expr);
image_operands_mask ^= SpvImageOperandsSampleMask;
arg_index++;
}
if (image_operands_mask) {
return Fail() << "unsupported image operands (" << image_operands_mask
<< "): " << inst.PrettyPrint();
}
// If any of the arguments are nullptr, then we've failed.
if (std::any_of(args.begin(), args.end(), [](auto* expr) { return expr == nullptr; })) {
return false;
}
auto* ident =
create<ast::IdentifierExpression>(Source{}, builder_.Symbols().Register(builtin_name));
auto* call_expr = create<ast::CallExpression>(Source{}, ident, std::move(args));
if (inst.type_id() != 0) {
// It returns a value.
const ast::Expression* value = call_expr;
// The result type, derived from the SPIR-V instruction.
auto* result_type = parser_impl_.ConvertType(inst.type_id());
auto* result_component_type = result_type;
if (auto* result_vector_type = As<Vector>(result_type)) {
result_component_type = result_vector_type->type;
}
        // For depth textures, the arity might not match WGSL:
// Operation SPIR-V WGSL
// normal sampling vec4 ImplicitLod f32
// normal sampling vec4 ExplicitLod f32
// compare sample f32 DrefImplicitLod f32
// compare sample f32 DrefExplicitLod f32
// texel load vec4 ImageFetch f32
// normal gather vec4 ImageGather vec4
// dref gather vec4 ImageDrefGather vec4
// Construct a 4-element vector with the result from the builtin in the
// first component.
if (texture_type->IsAnyOf<DepthTexture, DepthMultisampledTexture>()) {
if (is_non_dref_sample || (opcode == SpvOpImageFetch)) {
value = builder_.Construct(
Source{},
result_type->Build(builder_), // a vec4
ast::ExpressionList{value, parser_impl_.MakeNullValue(result_component_type),
parser_impl_.MakeNullValue(result_component_type),
parser_impl_.MakeNullValue(result_component_type)});
}
}
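        // For example (illustrative): sampling a depth texture yields an f32
        // in WGSL but a vec4 in SPIR-V, so the call is wrapped as:
        //   vec4<f32>(textureSample(t, s, coords), 0.0, 0.0, 0.0)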
// If necessary, convert the result to the signedness of the instruction
// result type. Compare the SPIR-V image's sampled component type with the
// component of the result type of the SPIR-V instruction.
auto* spirv_image_type = parser_impl_.GetSpirvTypeForHandleMemoryObjectDeclaration(*image);
if (!spirv_image_type || (spirv_image_type->opcode() != SpvOpTypeImage)) {
return Fail() << "invalid image type for image memory object declaration "
<< image->PrettyPrint();
}
auto* expected_component_type =
parser_impl_.ConvertType(spirv_image_type->GetSingleWordInOperand(0));
if (expected_component_type != result_component_type) {
// This occurs if one is signed integer and the other is unsigned integer,
// or vice versa. Perform a bitcast.
value =
create<ast::BitcastExpression>(Source{}, result_type->Build(builder_), call_expr);
}
if (!expected_component_type->Is<F32>() && IsSampledImageAccess(opcode)) {
// WGSL permits sampled image access only on float textures.
// Reject this case in the SPIR-V reader, at least until SPIR-V validation
// catches up with this rule and can reject it earlier in the workflow.
return Fail() << "sampled image must have float component type";
}
EmitConstDefOrWriteToHoistedVar(inst, {result_type, value});
} else {
// It's an image write. No value is returned, so make a statement out
// of the call.
AddStatement(create<ast::CallStatement>(Source{}, call_expr));
}
return success();
}
bool FunctionEmitter::EmitImageQuery(const spvtools::opt::Instruction& inst) {
// TODO(dneto): Reject cases that are valid in Vulkan but invalid in WGSL.
const spvtools::opt::Instruction* image = GetImage(inst);
if (!image) {
return false;
}
auto* texture_type = GetImageType(*image);
if (!texture_type) {
return false;
}
const auto opcode = inst.opcode();
switch (opcode) {
case SpvOpImageQuerySize:
case SpvOpImageQuerySizeLod: {
ast::ExpressionList exprs;
// Invoke textureDimensions.
// If the texture is arrayed, combine with the result from
// textureNumLayers.
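            // For example (illustrative): OpImageQuerySizeLod on a 2D arrayed
            // texture becomes something like:
            //   vec3<i32>(textureDimensions(t, lod), textureNumLayers(t))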
auto* dims_ident = create<ast::IdentifierExpression>(
Source{}, builder_.Symbols().Register("textureDimensions"));
ast::ExpressionList dims_args{GetImageExpression(inst)};
if (opcode == SpvOpImageQuerySizeLod) {
dims_args.push_back(ToI32(MakeOperand(inst, 1)).expr);
}
const ast::Expression* dims_call =
create<ast::CallExpression>(Source{}, dims_ident, dims_args);
auto dims = texture_type->dims;
if ((dims == ast::TextureDimension::kCube) ||
(dims == ast::TextureDimension::kCubeArray)) {
                // textureDimensions returns a 3-element vector but SPIR-V expects 2.
dims_call =
create<ast::MemberAccessorExpression>(Source{}, dims_call, PrefixSwizzle(2));
}
exprs.push_back(dims_call);
if (ast::IsTextureArray(dims)) {
auto* layers_ident = create<ast::IdentifierExpression>(
Source{}, builder_.Symbols().Register("textureNumLayers"));
exprs.push_back(create<ast::CallExpression>(
Source{}, layers_ident, ast::ExpressionList{GetImageExpression(inst)}));
}
auto* result_type = parser_impl_.ConvertType(inst.type_id());
TypedExpression expr = {
result_type, builder_.Construct(Source{}, result_type->Build(builder_), exprs)};
return EmitConstDefOrWriteToHoistedVar(inst, expr);
}
case SpvOpImageQueryLod:
return Fail() << "WGSL does not support querying the level of detail of "
"an image: "
<< inst.PrettyPrint();
case SpvOpImageQueryLevels:
case SpvOpImageQuerySamples: {
const auto* name =
(opcode == SpvOpImageQueryLevels) ? "textureNumLevels" : "textureNumSamples";
auto* levels_ident =
create<ast::IdentifierExpression>(Source{}, builder_.Symbols().Register(name));
const ast::Expression* ast_expr = create<ast::CallExpression>(
Source{}, levels_ident, ast::ExpressionList{GetImageExpression(inst)});
auto* result_type = parser_impl_.ConvertType(inst.type_id());
            // The SPIR-V result type must be integer scalar. The WGSL builtin
            // returns i32. If they aren't the same then convert the result.
if (!result_type->Is<I32>()) {
ast_expr = builder_.Construct(Source{}, result_type->Build(builder_),
ast::ExpressionList{ast_expr});
}
TypedExpression expr{result_type, ast_expr};
return EmitConstDefOrWriteToHoistedVar(inst, expr);
}
default:
break;
}
return Fail() << "unhandled image query: " << inst.PrettyPrint();
}
ast::ExpressionList FunctionEmitter::MakeCoordinateOperandsForImageAccess(
const spvtools::opt::Instruction& inst) {
if (!parser_impl_.success()) {
Fail();
return {};
}
const spvtools::opt::Instruction* image = GetImage(inst);
if (!image) {
return {};
}
if (inst.NumInOperands() < 1) {
Fail() << "image access is missing a coordinate parameter: " << inst.PrettyPrint();
return {};
}
// In SPIR-V for Shader, coordinates are:
// - floating point for sampling, dref sampling, gather, dref gather
// - integral for fetch, read, write
// In WGSL:
// - floating point for sampling, dref sampling, gather, dref gather
// - signed integral for textureLoad, textureStore
//
// The only conversions we have to do for WGSL are:
// - When the coordinates are unsigned integral, convert them to signed.
// - Array index is always i32
// The coordinates parameter is always in position 1.
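    // For example (illustrative): a fetch from a 2D arrayed texture with
    // vec3<u32> coordinates yields two WGSL arguments:
    //   vec2<i32>(coords.xy), i32(coords.z)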
TypedExpression raw_coords(MakeOperand(inst, 1));
if (!raw_coords) {
return {};
}
const Texture* texture_type = GetImageType(*image);
if (!texture_type) {
return {};
}
ast::TextureDimension dim = texture_type->dims;
// Number of regular coordinates.
uint32_t num_axes = ast::NumCoordinateAxes(dim);
bool is_arrayed = ast::IsTextureArray(dim);
    if ((num_axes == 0) || (num_axes > 3)) {
        Fail() << "unsupported image dimensionality for " << texture_type->TypeInfo().name
               << " prompted by " << inst.PrettyPrint();
        return {};
    }
bool is_proj = false;
switch (inst.opcode()) {
case SpvOpImageSampleProjImplicitLod:
case SpvOpImageSampleProjExplicitLod:
case SpvOpImageSampleProjDrefImplicitLod:
case SpvOpImageSampleProjDrefExplicitLod:
is_proj = true;
break;
default:
break;
}
const auto num_coords_required = num_axes + (is_arrayed ? 1 : 0) + (is_proj ? 1 : 0);
uint32_t num_coords_supplied = 0;
auto* component_type = raw_coords.type;
if (component_type->IsFloatScalar() || component_type->IsIntegerScalar()) {
num_coords_supplied = 1;
} else if (auto* vec_type = As<Vector>(raw_coords.type)) {
component_type = vec_type->type;
num_coords_supplied = vec_type->size;
}
if (num_coords_supplied == 0) {
Fail() << "bad or unsupported coordinate type for image access: " << inst.PrettyPrint();
return {};
}
if (num_coords_required > num_coords_supplied) {
Fail() << "image access required " << num_coords_required
<< " coordinate components, but only " << num_coords_supplied
<< " provided, in: " << inst.PrettyPrint();
return {};
}
ast::ExpressionList result;
    // Generates the expression for the WGSL coordinates, as a prefix swizzle
    // of the first num_axes components. If the result would be unsigned, also
    // converts it to a signed value of the same shape (scalar or vector).
    // Use a lambda so we only generate the expressions when we will actually
    // use them.
auto prefix_swizzle_expr = [this, num_axes, component_type, is_proj,
raw_coords]() -> const ast::Expression* {
auto* swizzle_type =
(num_axes == 1) ? component_type : ty_.Vector(component_type, num_axes);
auto* swizzle = create<ast::MemberAccessorExpression>(Source{}, raw_coords.expr,
PrefixSwizzle(num_axes));
if (is_proj) {
auto* q =
create<ast::MemberAccessorExpression>(Source{}, raw_coords.expr, Swizzle(num_axes));
auto* proj_div = builder_.Div(swizzle, q);
return ToSignedIfUnsigned({swizzle_type, proj_div}).expr;
} else {
return ToSignedIfUnsigned({swizzle_type, swizzle}).expr;
}
};
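    // For example (illustrative): a projective 2D sample with vec3<f32>
    // coordinates divides through by the last component, producing:
    //   (coords.xy / coords.z)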
if (is_arrayed) {
// The source must be a vector. It has at least one coordinate component
// and it must have an array component. Use a vector swizzle to get the
// first `num_axes` components.
result.push_back(prefix_swizzle_expr());
// Now get the array index.
const ast::Expression* array_index =
builder_.MemberAccessor(raw_coords.expr, Swizzle(num_axes));
if (component_type->IsFloatScalar()) {
// When converting from a float array layer to integer, Vulkan requires
// round-to-nearest, with preference for round-to-nearest-even.
// But i32(f32) in WGSL has unspecified rounding mode, so we have to
// explicitly specify the rounding.
array_index = builder_.Call("round", array_index);
}
// Convert it to a signed integer type, if needed.
result.push_back(ToI32({component_type, array_index}).expr);
} else {
if (num_coords_supplied == num_coords_required && !is_proj) {
// Pass the value through, with possible unsigned->signed conversion.
result.push_back(ToSignedIfUnsigned(raw_coords).expr);
} else {
            // Either more coordinates were supplied than needed, or this is a
            // projective access. Either way the source is a vector. Use a
            // vector swizzle to get the first `num_axes` components.
result.push_back(prefix_swizzle_expr());
}
}
return result;
}
const ast::Expression* FunctionEmitter::ConvertTexelForStorage(
const spvtools::opt::Instruction& inst,
TypedExpression texel,
const Texture* texture_type) {
auto* storage_texture_type = As<StorageTexture>(texture_type);
auto* src_type = texel.type;
if (!storage_texture_type) {
Fail() << "writing to other than storage texture: " << inst.PrettyPrint();
return nullptr;
}
const auto format = storage_texture_type->format;
auto* dest_type = parser_impl_.GetTexelTypeForFormat(format);
if (!dest_type) {
Fail();
return nullptr;
}
// The texel type is always a 4-element vector.
const uint32_t dest_count = 4u;
TINT_ASSERT(Reader, dest_type->Is<Vector>() && dest_type->As<Vector>()->size == dest_count);
TINT_ASSERT(Reader, dest_type->IsFloatVector() || dest_type->IsUnsignedIntegerVector() ||
dest_type->IsSignedIntegerVector());
if (src_type == dest_type) {
return texel.expr;
}
    // The component type must match in floatness and in integer signedness.
if ((src_type->IsFloatScalarOrVector() != dest_type->IsFloatVector()) ||
(src_type->IsUnsignedIntegerVector() != dest_type->IsUnsignedIntegerVector()) ||
(src_type->IsSignedIntegerVector() != dest_type->IsSignedIntegerVector())) {
Fail() << "invalid texel type for storage texture write: component must be "
"float, signed integer, or unsigned integer "
"to match the texture channel type: "
<< inst.PrettyPrint();
return nullptr;
}
const auto required_count = parser_impl_.GetChannelCountForFormat(format);
TINT_ASSERT(Reader, 0 < required_count && required_count <= 4);
const uint32_t src_count = src_type->IsScalar() ? 1 : src_type->As<Vector>()->size;
if (src_count < required_count) {
Fail() << "texel has too few components for storage texture: " << src_count
<< " provided but " << required_count << " required, in: " << inst.PrettyPrint();
return nullptr;
}
// It's valid for required_count < src_count. The extra components will
// be written out but the textureStore will ignore them.
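    // For example (illustrative): writing a vec2<f32> texel to a four-channel
    // float texture expands the texel value to:
    //   vec4<f32>(texel, 0.0, 0.0)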
if (src_count < dest_count) {
// Expand the texel to a 4 element vector.
auto* component_type = texel.type->IsScalar() ? texel.type : texel.type->As<Vector>()->type;
texel.type = ty_.Vector(component_type, dest_count);
ast::ExpressionList exprs;
exprs.push_back(texel.expr);
for (auto i = src_count; i < dest_count; i++) {
exprs.push_back(parser_impl_.MakeNullExpression(component_type).expr);
}
texel.expr = builder_.Construct(Source{}, texel.type->Build(builder_), std::move(exprs));
}
return texel.expr;
}
TypedExpression FunctionEmitter::ToI32(TypedExpression value) {
if (!value || value.type->Is<I32>()) {
return value;
}
return {ty_.I32(),
builder_.Construct(Source{}, builder_.ty.i32(), ast::ExpressionList{value.expr})};
}
TypedExpression FunctionEmitter::ToSignedIfUnsigned(TypedExpression value) {
if (!value || !value.type->IsUnsignedScalarOrVector()) {
return value;
}
if (auto* vec_type = value.type->As<Vector>()) {
auto* new_type = ty_.Vector(ty_.I32(), vec_type->size);
return {new_type,
builder_.Construct(new_type->Build(builder_), ast::ExpressionList{value.expr})};
}
return ToI32(value);
}
TypedExpression FunctionEmitter::MakeArrayLength(const spvtools::opt::Instruction& inst) {
if (inst.NumInOperands() != 2) {
// Binary parsing will fail on this anyway.
Fail() << "invalid array length: requires 2 operands: " << inst.PrettyPrint();
return {};
}
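    // For example (names illustrative):
    //   %len = OpArrayLength %uint %struct_ptr 2
    // becomes a WGSL call like:
    //   arrayLength(&myvar.last_member)
    // where last_member is the name registered for member 2 of the struct.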
const auto struct_ptr_id = inst.GetSingleWordInOperand(0);
const auto field_index = inst.GetSingleWordInOperand(1);
const auto struct_ptr_type_id = def_use_mgr_->GetDef(struct_ptr_id)->type_id();
// Trace through the pointer type to get to the struct type.
const auto struct_type_id = def_use_mgr_->GetDef(struct_ptr_type_id)->GetSingleWordInOperand(1);
const auto field_name = namer_.GetMemberName(struct_type_id, field_index);
if (field_name.empty()) {
Fail() << "struct index out of bounds for array length: " << inst.PrettyPrint();
return {};
}
auto member_expr = MakeExpression(struct_ptr_id);
if (!member_expr) {
return {};
}
if (member_expr.type->Is<Pointer>()) {
member_expr = Dereference(member_expr);
}
auto* member_ident =
create<ast::IdentifierExpression>(Source{}, builder_.Symbols().Register(field_name));
auto* member_access =
create<ast::MemberAccessorExpression>(Source{}, member_expr.expr, member_ident);
// Generate the builtin function call.
auto* call_expr = builder_.Call(Source{}, "arrayLength", builder_.AddressOf(member_access));
return {parser_impl_.ConvertType(inst.type_id()), call_expr};
}
TypedExpression FunctionEmitter::MakeOuterProduct(const spvtools::opt::Instruction& inst) {
// Synthesize the result.
auto col = MakeOperand(inst, 0);
auto row = MakeOperand(inst, 1);
auto* col_ty = As<Vector>(col.type);
auto* row_ty = As<Vector>(row.type);
auto* result_ty = As<Matrix>(parser_impl_.ConvertType(inst.type_id()));
    if (!col_ty || !row_ty || !result_ty || result_ty->type != col_ty->type ||
result_ty->type != row_ty->type || result_ty->columns != row_ty->size ||
result_ty->rows != col_ty->size) {
Fail() << "invalid outer product instruction: bad types " << inst.PrettyPrint();
return {};
}
// Example:
// c : vec3 column vector
// r : vec2 row vector
// OuterProduct c r : mat2x3 (2 columns, 3 rows)
// Result:
// | c.x * r.x c.x * r.y |
// | c.y * r.x c.y * r.y |
// | c.z * r.x c.z * r.y |
ast::ExpressionList result_columns;
for (uint32_t icol = 0; icol < result_ty->columns; icol++) {
ast::ExpressionList result_row;
auto* row_factor = create<ast::MemberAccessorExpression>(Source{}, row.expr, Swizzle(icol));
for (uint32_t irow = 0; irow < result_ty->rows; irow++) {
auto* column_factor =
create<ast::MemberAccessorExpression>(Source{}, col.expr, Swizzle(irow));
auto* elem = create<ast::BinaryExpression>(Source{}, ast::BinaryOp::kMultiply,
row_factor, column_factor);
result_row.push_back(elem);
}
result_columns.push_back(builder_.Construct(Source{}, col_ty->Build(builder_), result_row));
}
return {result_ty, builder_.Construct(Source{}, result_ty->Build(builder_), result_columns)};
}
bool FunctionEmitter::MakeVectorInsertDynamic(const spvtools::opt::Instruction& inst) {
// For
// %result = OpVectorInsertDynamic %type %src_vector %component %index
// there are two cases.
//
// Case 1:
// The %src_vector value has already been hoisted into a variable.
// In this case, assign %src_vector to that variable, then write the
// component into the right spot:
//
// hoisted = src_vector;
// hoisted[index] = component;
//
// Case 2:
// The %src_vector value is not hoisted. In this case, make a temporary
// variable with the %src_vector contents, then write the component,
// and then make a let-declaration that reads the value out:
//
// var temp : type = src_vector;
// temp[index] = component;
// let result : type = temp;
//
    // Then use result everywhere the original SPIR-V id is used. Using a
    // let-declaration like this avoids reloading the value many times.
auto* type = parser_impl_.ConvertType(inst.type_id());
auto src_vector = MakeOperand(inst, 0);
auto component = MakeOperand(inst, 1);
auto index = MakeOperand(inst, 2);
std::string var_name;
auto original_value_name = namer_.Name(inst.result_id());
const bool hoisted = WriteIfHoistedVar(inst, src_vector);
if (hoisted) {
// The variable was already declared in an earlier block.
var_name = original_value_name;
// Assign the source vector value to it.
builder_.Assign({}, builder_.Expr(var_name), src_vector.expr);
} else {
// Synthesize the temporary variable.
// It doesn't correspond to a SPIR-V ID, so we don't use the ordinary
// API in parser_impl_.
var_name = namer_.MakeDerivedName(original_value_name);
auto* temp_var = builder_.Var(var_name, type->Build(builder_), ast::StorageClass::kNone,
src_vector.expr);
AddStatement(builder_.Decl({}, temp_var));
}
auto* lhs = create<ast::IndexAccessorExpression>(Source{}, builder_.Expr(var_name), index.expr);
if (!lhs) {
return false;
}
AddStatement(builder_.Assign(lhs, component.expr));
if (hoisted) {
// The hoisted variable itself stands for this result ID.
return success();
}
// Create a new let-declaration that is initialized by the contents
// of the temporary variable.
return EmitConstDefinition(inst, {type, builder_.Expr(var_name)});
}
bool FunctionEmitter::MakeCompositeInsert(const spvtools::opt::Instruction& inst) {
// For
// %result = OpCompositeInsert %type %object %composite 1 2 3 ...
// there are two cases.
//
// Case 1:
// The %composite value has already been hoisted into a variable.
// In this case, assign %composite to that variable, then write the
// component into the right spot:
//
// hoisted = composite;
// hoisted[index].x = object;
//
// Case 2:
// The %composite value is not hoisted. In this case, make a temporary
// variable with the %composite contents, then write the component,
// and then make a let-declaration that reads the value out:
//
// var temp : type = composite;
// temp[index].x = object;
// let result : type = temp;
//
    // Then use result everywhere the original SPIR-V id is used. Using a
    // let-declaration like this avoids reloading the value many times.
//
// This technique is a combination of:
// - making a temporary variable and constant declaration, like what we do
// for VectorInsertDynamic, and
    // - building up an access-chain-like access, as for CompositeExtract, but
    //   on the left-hand side of the assignment.
auto* type = parser_impl_.ConvertType(inst.type_id());
auto component = MakeOperand(inst, 0);
auto src_composite = MakeOperand(inst, 1);
std::string var_name;
auto original_value_name = namer_.Name(inst.result_id());
const bool hoisted = WriteIfHoistedVar(inst, src_composite);
if (hoisted) {
// The variable was already declared in an earlier block.
var_name = original_value_name;
// Assign the source composite value to it.
builder_.Assign({}, builder_.Expr(var_name), src_composite.expr);
} else {
// Synthesize a temporary variable.
// It doesn't correspond to a SPIR-V ID, so we don't use the ordinary
// API in parser_impl_.
var_name = namer_.MakeDerivedName(original_value_name);
auto* temp_var = builder_.Var(var_name, type->Build(builder_), ast::StorageClass::kNone,
src_composite.expr);
AddStatement(builder_.Decl({}, temp_var));
}
TypedExpression seed_expr{type, builder_.Expr(var_name)};
// The left-hand side of the assignment *looks* like a decomposition.
TypedExpression lhs = MakeCompositeValueDecomposition(inst, seed_expr, inst.type_id(), 2);
if (!lhs) {
return false;
}
AddStatement(builder_.Assign(lhs.expr, component.expr));
if (hoisted) {
// The hoisted variable itself stands for this result ID.
return success();
}
// Create a new let-declaration that is initialized by the contents
// of the temporary variable.
return EmitConstDefinition(inst, {type, builder_.Expr(var_name)});
}
TypedExpression FunctionEmitter::AddressOf(TypedExpression expr) {
auto* ref = expr.type->As<Reference>();
if (!ref) {
Fail() << "AddressOf() called on non-reference type";
return {};
}
return {
ty_.Pointer(ref->type, ref->storage_class),
create<ast::UnaryOpExpression>(Source{}, ast::UnaryOp::kAddressOf, expr.expr),
};
}
TypedExpression FunctionEmitter::Dereference(TypedExpression expr) {
auto* ptr = expr.type->As<Pointer>();
if (!ptr) {
Fail() << "Dereference() called on non-pointer type";
return {};
}
return {
ptr->type,
create<ast::UnaryOpExpression>(Source{}, ast::UnaryOp::kIndirection, expr.expr),
};
}
bool FunctionEmitter::IsFloatZero(uint32_t value_id) {
if (const auto* c = constant_mgr_->FindDeclaredConstant(value_id)) {
if (const auto* float_const = c->AsFloatConstant()) {
return 0.0f == float_const->GetFloatValue();
}
if (c->AsNullConstant()) {
// Valid SPIR-V requires it to be a float value anyway.
return true;
}
}
return false;
}
bool FunctionEmitter::IsFloatOne(uint32_t value_id) {
if (const auto* c = constant_mgr_->FindDeclaredConstant(value_id)) {
if (const auto* float_const = c->AsFloatConstant()) {
return 1.0f == float_const->GetFloatValue();
}
}
return false;
}
FunctionEmitter::FunctionDeclaration::FunctionDeclaration() = default;
FunctionEmitter::FunctionDeclaration::~FunctionDeclaration() = default;
} // namespace tint::reader::spirv
TINT_INSTANTIATE_TYPEINFO(tint::reader::spirv::StatementBuilder);
TINT_INSTANTIATE_TYPEINFO(tint::reader::spirv::SwitchStatementBuilder);
TINT_INSTANTIATE_TYPEINFO(tint::reader::spirv::IfStatementBuilder);
TINT_INSTANTIATE_TYPEINFO(tint::reader::spirv::LoopStatementBuilder);