mirror of
https://github.com/encounter/dawn-cmake.git
synced 2025-05-13 10:51:35 +00:00
Implement new transform UnwindDiscardFunctions that replaces discard statements with setting a module-level bool, adds a check and return for this bool after every function call that may discard, and finally invokes a single function that executes a discard from entry-point (top-level) functions. Regenerated tests and removed HLSL ones that used to fail FXC because it had difficulty with discard.

Bug: tint:1478
Bug: chromium:1118
Change-Id: I09d680f59e2d5d0cad907bfbbdd426aae76d4bf3
Reviewed-on: https://dawn-review.googlesource.com/c/tint/+/84221
Reviewed-by: James Price <jrprice@google.com>
Reviewed-by: Ben Clayton <bclayton@google.com>
Kokoro: Kokoro <noreply+kokoro@google.com>
Commit-Queue: Antonio Maiorano <amaiorano@google.com>
433 lines
15 KiB
C++
433 lines
15 KiB
C++
// Copyright 2022 The Tint Authors.
|
|
//
|
|
// Licensed under the Apache License, Version 2.0 (the "License");
|
|
// you may not use this file except in compliance with the License.
|
|
// You may obtain a copy of the License at
|
|
//
|
|
// http://www.apache.org/licenses/LICENSE-2.0
|
|
//
|
|
// Unless required by applicable law or agreed to in writing, software
|
|
// distributed under the License is distributed on an "AS IS" BASIS,
|
|
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
|
// See the License for the specific language governing permissions and
|
|
// limitations under the License.
|
|
|
|
#include "src/tint/transform/unwind_discard_functions.h"
|
|
|
|
#include <memory>
|
|
#include <string>
|
|
#include <unordered_set>
|
|
#include <utility>
|
|
#include <vector>
|
|
|
|
#include "src/tint/ast/discard_statement.h"
|
|
#include "src/tint/ast/return_statement.h"
|
|
#include "src/tint/ast/traverse_expressions.h"
|
|
#include "src/tint/sem/block_statement.h"
|
|
#include "src/tint/sem/call.h"
|
|
#include "src/tint/sem/for_loop_statement.h"
|
|
#include "src/tint/sem/function.h"
|
|
#include "src/tint/sem/if_statement.h"
|
|
|
|
// Instantiates Tint's type information (TypeInfo) for UnwindDiscardFunctions.
// Presumably this is what backs Castable queries such as Is<>/As<> on the
// transform; confirm against the TINT_INSTANTIATE_TYPEINFO definition.
TINT_INSTANTIATE_TYPEINFO(tint::transform::UnwindDiscardFunctions);
|
|
|
|
namespace tint::transform {
|
|
namespace {
|
|
|
|
/// Per-run implementation of the UnwindDiscardFunctions transform.
///
/// As implemented below:
///  * A `discard;` in an entry point is replaced by a call to the module-scope
///    function "tint_discard_func" (which holds the program's single discard
///    statement), followed by a return of the function's default value.
///  * A `discard;` in any other function is replaced by assigning `true` to the
///    module-scope private bool "tint_discard", followed by a return of the
///    function's default value.
///  * Each statement whose expression may (transitively) discard gets an
///    `if (tint_discard) { ... return ...; }` check. The check normally goes
///    after the statement. For if/switch/return statements and for-loop
///    initializers, the expression is instead hoisted into a temporary and the
///    check goes before. In entry points the check also calls
///    "tint_discard_func" before returning.
/// The ICE messages in Statement() indicate transform::PromoteSideEffectsToDecl
/// is expected to have run first.
class State {
 private:
  // Clone context for the source -> destination program.
  CloneContext& ctx;
  // Builder for the destination program (`*ctx.dst`).
  ProgramBuilder& b;
  // Semantic info of the source program (`ctx.src->Sem()`).
  const sem::Info& sem;
  Symbol module_discard_var_name;   // Use ModuleDiscardVarName() to read
  Symbol module_discard_func_name;  // Use ModuleDiscardFuncName() to read

  // For the input statement, returns the block and statement within that
  // block to insert before/after.
  //
  //  - If `stmt`'s semantic parent is a block: {that block, stmt}.
  //  - If `stmt` is the initializer of a for-loop: {fl->Block(), the for-loop}.
  //    New statements go around the loop, not inside its header.
  //    `fl->Block()` is presumably the block enclosing the loop.
  //  - Otherwise (e.g. a for-loop continuing statement): raises an ICE and
  //    returns {nullptr, nullptr}.
  // Also returns {nullptr, nullptr} if `stmt` has no semantic node.
  // NOTE(review): callers dereference `.first` unconditionally, so the
  // {nullptr, nullptr} results are only safe if the ICE aborts.
  std::pair<const sem::BlockStatement*, const ast::Statement*>
  GetInsertionPoint(const ast::Statement* stmt) {
    using RetType =
        std::pair<const sem::BlockStatement*, const ast::Statement*>;

    if (auto* sem_stmt = sem.Get(stmt)) {
      auto* parent = sem_stmt->Parent();
      return Switch(
          parent,
          [&](const sem::BlockStatement* block) -> RetType {
            // Common case: `stmt` lives directly in a block, so insert
            // relative to `stmt` itself within that block.
            return {block, stmt};
          },
          [&](const sem::ForLoopStatement* fl) -> RetType {
            // `stmt` is either the for loop initializer or the continuing
            // statement of a for-loop.
            if (fl->Declaration()->initializer == stmt) {
              // For loop init, insert relative to the for loop itself.
              return {fl->Block(), fl->Declaration()};
            }

            // The continuing statement has no position where code can be
            // inserted before/after it.
            TINT_ICE(Transform, b.Diagnostics())
                << "cannot insert before or after continuing statement of a "
                   "for-loop";
            return {};
          },
          [&](Default) -> RetType {
            TINT_ICE(Transform, b.Diagnostics())
                << "expected parent of statement to be either a block or for "
                   "loop";
            return {};
          });
    }

    return {};
  }

  // Returns the semantic parent of `block` cast to TO. Returns nullptr if
  // `block` has no semantic node or its parent is not a TO.
  template <typename TO>
  const TO* ParentAs(const ast::BlockStatement* block) {
    if (auto* sem_block = sem.Get(block)) {
      return As<TO>(sem_block->Parent());
    }
    return nullptr;
  }

  // Returns true if `sem_expr` is non-null and its behaviors contain
  // sem::Behavior::kDiscard, i.e. evaluating it may (transitively, through a
  // call) execute a discard statement. Because it accepts null, callers can
  // pass the result of sem.Get() on optional expressions directly.
  bool MayDiscard(const sem::Expression* sem_expr) {
    return sem_expr && sem_expr->Behaviors().Contains(sem::Behavior::kDiscard);
  }

  // Lazily creates and returns the name of the module bool variable for
  // whether to discard: "tint_discard". It is created with
  // b.Symbols().New(), which presumably uniquifies the name on collision.
  // On first use, this declares a module-scope `var<private>` bool,
  // initialized to false, in the destination program. Programs that never
  // need it therefore don't get it.
  Symbol ModuleDiscardVarName() {
    if (!module_discard_var_name.IsValid()) {
      module_discard_var_name = b.Symbols().New("tint_discard");
      ctx.dst->Global(module_discard_var_name, b.ty.bool_(), b.Expr(false),
                      ast::StorageClass::kPrivate);
    }
    return module_discard_var_name;
  }

  // Lazily creates and returns the name of the function that contains a
  // single discard statement: "tint_discard_func". The function takes no
  // parameters and returns void.
  // We do this to avoid having multiple discard statements in a single
  // program, which causes problems in certain backends (see crbug.com/1118).
  Symbol ModuleDiscardFuncName() {
    if (!module_discard_func_name.IsValid()) {
      module_discard_func_name = b.Symbols().New("tint_discard_func");
      b.Func(module_discard_func_name, {}, b.ty.void_(), {b.Discard()});
    }
    return module_discard_func_name;
  }

  // Creates "return <default return value>;" based on the return type of
  // `stmt`'s owning function. For a void function this is a bare `return;`.
  // Otherwise the value is a zero-argument constructor of the declared return
  // type, e.g. `return vec4<f32>();`.
  const ast::ReturnStatement* Return(const ast::Statement* stmt) {
    const ast::Expression* ret_val = nullptr;
    auto* ret_type = sem.Get(stmt)->Function()->Declaration()->return_type;
    if (!ret_type->Is<ast::Void>()) {
      ret_val = b.Construct(ctx.Clone(ret_type));
    }
    return b.Return(ret_val);
  }

  // Returns true if the function that `stmt` is in is an entry point.
  bool IsInEntryPointFunc(const ast::Statement* stmt) {
    return sem.Get(stmt)->Function()->Declaration()->IsEntryPoint();
  }

  // Creates "tint_discard_func();". This declares the discard function on
  // first use (see ModuleDiscardFuncName()).
  const ast::CallStatement* CallDiscardFunc() {
    auto func_name = ModuleDiscardFuncName();
    return b.CallStmt(b.Call(func_name));
  }

  // Creates and returns a new if-statement of the form:
  //
  //   if (tint_discard) {
  //     return <default value>;
  //   }
  //
  // or if `stmt` is in an entry point function:
  //
  //   if (tint_discard) {
  //     tint_discard_func();
  //     return <default value>;
  //   }
  //
  // The statement is only created, not inserted; callers place it with
  // ctx.InsertBefore/InsertAfter.
  const ast::IfStatement* IfDiscardReturn(const ast::Statement* stmt) {
    ast::StatementList stmts;

    // For entry point functions, also emit the discard statement
    if (IsInEntryPointFunc(stmt)) {
      stmts.emplace_back(CallDiscardFunc());
    }

    stmts.emplace_back(Return(stmt));

    auto var_name = ModuleDiscardVarName();
    return b.If(var_name, b.Block(stmts));
  }

  // Hoists `sem_expr` into a new `var` with an inferred type, initialized
  // with a clone of the expression. The var is followed by an
  // `IfDiscardReturn`. Both are inserted, in that order, at the position
  // given by GetInsertionPoint(): before `stmt`, or before the enclosing
  // for-loop when `stmt` is its initializer. Returns a copy of `stmt` in
  // which the expression is replaced by the new variable.
  //
  // For example, if `stmt` is:
  //
  //   return f();
  //
  // This function will transform this to:
  //
  //   var t1 = f();
  //   if (tint_discard) {
  //     return <default value>;
  //   }
  //   return t1;
  //
  // (In an entry point, `tint_discard_func();` precedes the inner return.)
  //
  // Call statements are special-cased: the returned replacement is the phony
  // assignment `_ = t1;`.
  //
  // NOTE(review): the hoisted declaration is always `var t = <expr>;`, which
  // looks wrong in two cases reachable via TryInsertAfter() for a for-loop
  // initializer. Confirm that PromoteSideEffectsToDecl or validation rules
  // both out:
  //  - A call statement whose callee returns void (e.g. `for (f(); ; )` with
  //    `fn f()`): `var t = f();` is not valid WGSL.
  //  - An assignment whose left-hand side may discard: the LHS value is
  //    copied into the temporary, and the assignment then writes to the
  //    temporary instead of the original location.
  const ast::Statement* HoistAndInsertBefore(const ast::Statement* stmt,
                                             const sem::Expression* sem_expr) {
    auto* expr = sem_expr->Declaration();

    auto ip = GetInsertionPoint(stmt);
    auto var_name = b.Sym();
    auto* decl = b.Decl(b.Var(var_name, nullptr, ctx.Clone(expr)));
    ctx.InsertBefore(ip.first->Declaration()->statements, ip.second, decl);

    ctx.InsertBefore(ip.first->Declaration()->statements, ip.second,
                     IfDiscardReturn(stmt));

    auto* var_expr = b.Expr(var_name);

    // Special handling for CallStatement as we can only replace its expression
    // with a CallExpression.
    if (stmt->Is<ast::CallStatement>()) {
      // We could replace the call statement with no statement, but we can't do
      // that with transforms (yet), so just return a phony assignment.
      return b.Assign(b.Phony(), var_expr);
    }

    // Clone `stmt` with `expr` swapped for the hoisted variable.
    ctx.Replace(expr, var_expr);
    return ctx.CloneWithoutTransform(stmt);
  }

  // Returns true if `stmt` is a for-loop initializer statement.
  bool IsForLoopInitStatement(const ast::Statement* stmt) {
    if (auto* sem_stmt = sem.Get(stmt)) {
      if (auto* sem_fl = As<sem::ForLoopStatement>(sem_stmt->Parent())) {
        return sem_fl->Declaration()->initializer == stmt;
      }
    }
    return false;
  }

  // Inserts an `IfDiscardReturn` after `stmt` if possible (i.e. `stmt` is not
  // in a for-loop init). Otherwise falls back to HoistAndInsertBefore, which
  // hoists `sem_expr` to a var followed by an `IfDiscardReturn` before `stmt`.
  //
  // Returns nullptr when the check was inserted after `stmt`, meaning `stmt`
  // itself is kept. Otherwise returns HoistAndInsertBefore()'s replacement
  // statement.
  //
  // For example, if `stmt` is:
  //
  //   let r = f();
  //
  // This function will transform this to:
  //
  //   let r = f();
  //   if (tint_discard) {
  //     return <default value>;
  //   }
  const ast::Statement* TryInsertAfter(const ast::Statement* stmt,
                                       const sem::Expression* sem_expr) {
    // If `stmt` is the init of a for-loop, hoist and insert before instead.
    if (IsForLoopInitStatement(stmt)) {
      return HoistAndInsertBefore(stmt, sem_expr);
    }

    auto ip = GetInsertionPoint(stmt);
    ctx.InsertAfter(ip.first->Declaration()->statements, ip.second,
                    IfDiscardReturn(stmt));
    return nullptr;  // Don't replace current statement
  }

  // Replaces the input discard statement with either setting the module level
  // discard bool ("tint_discard = true") or calling the discard function
  // ("tint_discard_func()"), followed by a default return statement.
  // The setter or call is inserted before the discard, and the returned
  // `return` statement replaces the discard itself.
  //
  // Replaces "discard;" with:
  //
  //   tint_discard = true;
  //   return <default value>;
  //
  // Or, if `stmt` is in an entry point function, replaces it with:
  //
  //   tint_discard_func();
  //   return <default value>;
  const ast::Statement* ReplaceDiscardStatement(
      const ast::DiscardStatement* stmt) {
    const ast::Statement* to_insert = nullptr;
    if (IsInEntryPointFunc(stmt)) {
      to_insert = CallDiscardFunc();
    } else {
      auto var_name = ModuleDiscardVarName();
      to_insert = b.Assign(var_name, true);
    }

    auto ip = GetInsertionPoint(stmt);
    ctx.InsertBefore(ip.first->Declaration()->statements, ip.second, to_insert);
    return Return(stmt);
  }

  // Processes a single statement `stmt`.
  //
  // Returns the statement that should replace `stmt`, or nullptr to keep it.
  // Any extra statements (hoisted vars, discard checks, discard setters) are
  // registered on `ctx` via InsertBefore/InsertAfter as a side effect.
  //
  //  - Assignment, call, and variable-declaration-with-initializer statements
  //    that may discard are handled by TryInsertAfter().
  //  - If, return and switch statements whose condition/value may discard are
  //    handled by HoistAndInsertBefore().
  //  - Else-if and for-loop conditions that may discard are not handled; they
  //    raise an ICE.
  //  - Statement kinds not listed in the Switch presumably yield nullptr.
  const ast::Statement* Statement(const ast::Statement* stmt) {
    return Switch(
        stmt,
        [&](const ast::DiscardStatement* s) -> const ast::Statement* {
          return ReplaceDiscardStatement(s);
        },
        [&](const ast::AssignmentStatement* s) -> const ast::Statement* {
          auto* sem_lhs = sem.Get(s->lhs);
          auto* sem_rhs = sem.Get(s->rhs);
          // The LHS takes precedence. Only one side is expected to discard.
          if (MayDiscard(sem_lhs)) {
            if (MayDiscard(sem_rhs)) {
              TINT_ICE(Transform, b.Diagnostics())
                  << "Unexpected: both sides of assignment statement may "
                     "discard. Make sure transform::PromoteSideEffectsToDecl "
                     "was run first.";
            }
            return TryInsertAfter(s, sem_lhs);
          } else if (MayDiscard(sem_rhs)) {
            return TryInsertAfter(s, sem_rhs);
          }
          return nullptr;
        },
        [&](const ast::CallStatement* s) -> const ast::Statement* {
          auto* sem_expr = sem.Get(s->expr);
          if (!MayDiscard(sem_expr)) {
            return nullptr;
          }
          return TryInsertAfter(s, sem_expr);
        },
        [&](const ast::ElseStatement* s) -> const ast::Statement* {
          // Validation only: this case always returns nullptr.
          if (MayDiscard(sem.Get(s->condition))) {
            TINT_ICE(Transform, b.Diagnostics())
                << "Unexpected ElseIf condition that may discard. Make sure "
                   "transform::PromoteSideEffectsToDecl was run first.";
          }
          return nullptr;
        },
        [&](const ast::ForLoopStatement* s) -> const ast::Statement* {
          // Validation only. The initializer and continuing statements are
          // processed separately by Run().
          if (MayDiscard(sem.Get(s->condition))) {
            TINT_ICE(Transform, b.Diagnostics())
                << "Unexpected ForLoopStatement condition that may discard. "
                   "Make sure transform::PromoteSideEffectsToDecl was run "
                   "first.";
          }
          return nullptr;
        },
        [&](const ast::IfStatement* s) -> const ast::Statement* {
          auto* sem_expr = sem.Get(s->condition);
          if (!MayDiscard(sem_expr)) {
            return nullptr;
          }
          return HoistAndInsertBefore(s, sem_expr);
        },
        [&](const ast::ReturnStatement* s) -> const ast::Statement* {
          auto* sem_expr = sem.Get(s->value);
          if (!MayDiscard(sem_expr)) {
            return nullptr;
          }
          return HoistAndInsertBefore(s, sem_expr);
        },
        [&](const ast::SwitchStatement* s) -> const ast::Statement* {
          auto* sem_expr = sem.Get(s->condition);
          if (!MayDiscard(sem_expr)) {
            return nullptr;
          }
          return HoistAndInsertBefore(s, sem_expr);
        },
        [&](const ast::VariableDeclStatement* s) -> const ast::Statement* {
          auto* var = s->variable;
          if (!var->constructor) {
            return nullptr;
          }
          auto* sem_expr = sem.Get(var->constructor);
          if (!MayDiscard(sem_expr)) {
            return nullptr;
          }
          return TryInsertAfter(s, sem_expr);
        });
  }

 public:
  /// Constructor
  /// @param ctx_in the context
  explicit State(CloneContext& ctx_in)
      : ctx(ctx_in), b(*ctx_in.dst), sem(ctx_in.src->Sem()) {}

  /// Runs the transform.
  ///
  /// Registers a ReplaceAll() callback that fires for each BlockStatement as
  /// it is cloned. The callback runs Statement() on each of the block's
  /// statements, and on the initializer and continuing statements of any
  /// for-loop among them. Statement() registers replacements and insertions
  /// on `ctx`. The callback then returns nullptr, which presumably tells
  /// CloneContext to clone the block as usual with those edits applied.
  /// Finally, ctx.Clone() performs the clone.
  void Run() {
    ctx.ReplaceAll(
        [&](const ast::BlockStatement* block) -> const ast::Statement* {
          // If this block is for an else-if statement, process the else-if now
          // before processing its block statements.
          // NOTE: we can't replace else statements at this point - this would
          // need to be done when replacing the parent if-statement. However, in
          // this transform, we don't ever expect to need to do this as else-ifs
          // are converted to else { if } by PromoteSideEffectsToDecl, so this
          // is only for validation.
          if (auto* sem_else = ParentAs<sem::ElseStatement>(block)) {
            if (auto* new_stmt = Statement(sem_else->Declaration())) {
              // Not expected to be reached: Statement() always returns nullptr
              // for an ElseStatement. If it is reached, this assert fails by
              // design.
              TINT_ASSERT(Transform, new_stmt == nullptr);
              return nullptr;
            }
          }

          // Iterate block statements and replace them as needed.
          for (auto* stmt : block->statements) {
            if (auto* new_stmt = Statement(stmt)) {
              ctx.Replace(stmt, new_stmt);
            }

            // Handle for loops, as they are the only other AST node that
            // contains statements outside of BlockStatements.
            if (auto* fl = stmt->As<ast::ForLoopStatement>()) {
              if (auto* new_stmt = Statement(fl->initializer)) {
                ctx.Replace(fl->initializer, new_stmt);
              }
              if (auto* new_stmt = Statement(fl->continuing)) {
                // NOTE: Should never reach here as we cannot discard in a
                // continuing block.
                ctx.Replace(fl->continuing, new_stmt);
              }
            }
          }

          return nullptr;
        });

    ctx.Clone();
  }
};
|
|
|
|
} // namespace
|
|
|
|
/// Constructor. Per-run state lives in the file-local State, which Run()
/// creates on each invocation.
UnwindDiscardFunctions::UnwindDiscardFunctions() = default;
/// Destructor
UnwindDiscardFunctions::~UnwindDiscardFunctions() = default;
|
|
|
|
/// Applies the transform by running a fresh, file-local State over `ctx`.
/// @param ctx the clone context for the source and destination programs
/// The input and output DataMaps are not used by this transform.
void UnwindDiscardFunctions::Run(CloneContext& ctx,
                                 const DataMap& /* inputs */,
                                 DataMap& /* outputs */) const {
  State{ctx}.Run();
}
|
|
|
|
bool UnwindDiscardFunctions::ShouldRun(const Program* program,
|
|
const DataMap& /*data*/) const {
|
|
auto& sem = program->Sem();
|
|
for (auto* f : program->AST().Functions()) {
|
|
if (sem.Get(f)->Behaviors().Contains(sem::Behavior::kDiscard)) {
|
|
return true;
|
|
}
|
|
}
|
|
return false;
|
|
}
|
|
|
|
} // namespace tint::transform
|