Rework all transforms to transform-on-copy

instead of transform-in-place.

This is a public API breaking change, so I've added the `DAWN_USE_NEW_TINT_TRANSFORM_API` define which is used by Dawn to know which API to use.

As we're going to have to go through the effort of an API breaking change, use this as an opportunity to rename Transformer to Transform, and remove 'Transform' from each of the transforms themselves (they're already in the transform namespace).

Bug: tint:390
Bug: tint:389
Change-Id: I1017507524b76bb4ffd26b95e550ef53ddc891c9
Reviewed-on: https://dawn-review.googlesource.com/c/tint/+/34800
Reviewed-by: dan sinclair <dsinclair@chromium.org>
Commit-Queue: Ben Clayton <bclayton@google.com>
This commit is contained in:
Ben Clayton 2020-12-04 09:06:09 +00:00 committed by Commit Bot service account
parent 0f37afb74e
commit 00b77a80ab
25 changed files with 1210 additions and 1317 deletions

View File

@ -413,16 +413,16 @@ source_set("libtint_core_src") {
"src/scope_stack.h", "src/scope_stack.h",
"src/source.cc", "src/source.cc",
"src/source.h", "src/source.h",
"src/transform/emit_vertex_point_size_transform.cc", "src/transform/emit_vertex_point_size.cc",
"src/transform/emit_vertex_point_size_transform.h", "src/transform/emit_vertex_point_size.h",
"src/transform/bound_array_accessors_transform.cc", "src/transform/bound_array_accessors.cc",
"src/transform/bound_array_accessors_transform.h", "src/transform/bound_array_accessors.h",
"src/transform/manager.cc", "src/transform/manager.cc",
"src/transform/manager.h", "src/transform/manager.h",
"src/transform/transformer.cc", "src/transform/transform.cc",
"src/transform/transformer.h", "src/transform/transform.h",
"src/transform/vertex_pulling_transform.cc", "src/transform/vertex_pulling.cc",
"src/transform/vertex_pulling_transform.h", "src/transform/vertex_pulling.h",
"src/type_determiner.cc", "src/type_determiner.cc",
"src/type_determiner.h", "src/type_determiner.h",
"src/validator/validator.cc", "src/validator/validator.cc",
@ -821,9 +821,9 @@ source_set("tint_unittests_core_src") {
"src/inspector/inspector_test.cc", "src/inspector/inspector_test.cc",
"src/namer_test.cc", "src/namer_test.cc",
"src/scope_stack_test.cc", "src/scope_stack_test.cc",
"src/transform/emit_vertex_point_size_transform_test.cc", "src/transform/emit_vertex_point_size_test.cc",
"src/transform/bound_array_accessors_transform_test.cc", "src/transform/bound_array_accessors_test.cc",
"src/transform/vertex_pulling_transform_test.cc", "src/transform/vertex_pulling_test.cc",
"src/type_determiner_test.cc", "src/type_determiner_test.cc",
"src/validator/validator_control_block_test.cc", "src/validator/validator_control_block_test.cc",
"src/validator/validator_function_test.cc", "src/validator/validator_function_test.cc",

View File

@ -25,10 +25,10 @@
#include "src/inspector/inspector.h" #include "src/inspector/inspector.h"
#include "src/namer.h" #include "src/namer.h"
#include "src/reader/reader.h" #include "src/reader/reader.h"
#include "src/transform/bound_array_accessors_transform.h" #include "src/transform/bound_array_accessors.h"
#include "src/transform/emit_vertex_point_size_transform.h" #include "src/transform/emit_vertex_point_size.h"
#include "src/transform/manager.h" #include "src/transform/manager.h"
#include "src/transform/vertex_pulling_transform.h" #include "src/transform/vertex_pulling.h"
#include "src/type_determiner.h" #include "src/type_determiner.h"
#include "src/validator/validator.h" #include "src/validator/validator.h"
#include "src/writer/writer.h" #include "src/writer/writer.h"

View File

@ -71,7 +71,7 @@ const char kUsage[] = R"(Usage: tint [options] <input-file>
-ep <compute|fragment|vertex> <name> -- Output single entry point -ep <compute|fragment|vertex> <name> -- Output single entry point
--output-file <name> -- Output file name. Use "-" for standard output --output-file <name> -- Output file name. Use "-" for standard output
-o <name> -- Output file name. Use "-" for standard output -o <name> -- Output file name. Use "-" for standard output
--transform <name list> -- Runs transformers, name list is comma separated --transform <name list> -- Runs transforms, name list is comma separated
Available transforms: Available transforms:
bound_array_accessors bound_array_accessors
emit_vertex_point_size emit_vertex_point_size
@ -515,22 +515,24 @@ int main(int argc, const char** argv) {
if (name == "bound_array_accessors") { if (name == "bound_array_accessors") {
transform_manager.append( transform_manager.append(
std::make_unique<tint::transform::BoundArrayAccessorsTransform>( std::make_unique<tint::transform::BoundArrayAccessors>());
&mod));
} else if (name == "emit_vertex_point_size") { } else if (name == "emit_vertex_point_size") {
transform_manager.append( transform_manager.append(
std::make_unique<tint::transform::EmitVertexPointSizeTransform>( std::make_unique<tint::transform::EmitVertexPointSize>());
&mod));
} else { } else {
std::cerr << "Unknown transform name: " << name << std::endl; std::cerr << "Unknown transform name: " << name << std::endl;
return 1; return 1;
} }
} }
if (!transform_manager.Run(&mod)) {
std::cerr << "Transformer: " << transform_manager.error() << std::endl; auto out = transform_manager.Run(&mod);
if (out.diagnostics.contains_errors()) {
diag_formatter.format(out.diagnostics, diag_printer.get());
return 1; return 1;
} }
mod = std::move(out.module);
std::unique_ptr<tint::writer::Writer> writer; std::unique_ptr<tint::writer::Writer> writer;
#if TINT_BUILD_SPV_WRITER #if TINT_BUILD_SPV_WRITER

View File

@ -234,16 +234,16 @@ set(TINT_LIB_SRCS
scope_stack.h scope_stack.h
source.cc source.cc
source.h source.h
transform/emit_vertex_point_size_transform.cc transform/emit_vertex_point_size.cc
transform/emit_vertex_point_size_transform.h transform/emit_vertex_point_size.h
transform/bound_array_accessors_transform.cc transform/bound_array_accessors.cc
transform/bound_array_accessors_transform.h transform/bound_array_accessors.h
transform/manager.cc transform/manager.cc
transform/manager.h transform/manager.h
transform/transformer.cc transform/transform.cc
transform/transformer.h transform/transform.h
transform/vertex_pulling_transform.cc transform/vertex_pulling.cc
transform/vertex_pulling_transform.h transform/vertex_pulling.h
type_determiner.cc type_determiner.cc
type_determiner.h type_determiner.h
validator/validator.cc validator/validator.cc
@ -431,9 +431,9 @@ set(TINT_TEST_SRCS
inspector/inspector_test.cc inspector/inspector_test.cc
namer_test.cc namer_test.cc
scope_stack_test.cc scope_stack_test.cc
transform/emit_vertex_point_size_transform_test.cc transform/emit_vertex_point_size_test.cc
transform/bound_array_accessors_transform_test.cc transform/bound_array_accessors_test.cc
transform/vertex_pulling_transform_test.cc transform/vertex_pulling_test.cc
type_determiner_test.cc type_determiner_test.cc
validator/validator_control_block_test.cc validator/validator_control_block_test.cc
validator/validator_function_test.cc validator/validator_function_test.cc

View File

@ -31,19 +31,21 @@ Module::~Module() = default;
Module Module::Clone() { Module Module::Clone() {
Module out; Module out;
CloneContext ctx(&out); CloneContext ctx(&out);
Clone(&ctx);
return out;
}
void Module::Clone(CloneContext* ctx) {
for (auto* ty : constructed_types_) { for (auto* ty : constructed_types_) {
out.constructed_types_.emplace_back(ctx.Clone(ty)); ctx->mod->constructed_types_.emplace_back(ctx->Clone(ty));
} }
for (auto* var : global_variables_) { for (auto* var : global_variables_) {
out.global_variables_.emplace_back(ctx.Clone(var)); ctx->mod->global_variables_.emplace_back(ctx->Clone(var));
} }
for (auto* func : functions_) { for (auto* func : functions_) {
out.functions_.emplace_back(ctx.Clone(func)); ctx->mod->functions_.emplace_back(ctx->Clone(func));
} }
return out;
} }
Function* Module::FindFunctionByName(const std::string& name) const { Function* Module::FindFunctionByName(const std::string& name) const {

View File

@ -54,6 +54,10 @@ class Module {
/// @return a deep copy of this module /// @return a deep copy of this module
Module Clone(); Module Clone();
/// Clone this module into `ctx->mod` using the provided CloneContext
/// @param ctx the clone context
void Clone(CloneContext* ctx);
/// Add a global variable to the module /// Add a global variable to the module
/// @param var the variable to add /// @param var the variable to add
void AddGlobalVariable(Variable* var) { global_variables_.push_back(var); } void AddGlobalVariable(Variable* var) { global_variables_.push_back(var); }

View File

@ -84,10 +84,18 @@ class List {
/// adds a diagnostic to the end of this list. /// adds a diagnostic to the end of this list.
/// @param diag the diagnostic to append to this list. /// @param diag the diagnostic to append to this list.
void add(Diagnostic&& diag) { void add(Diagnostic&& diag) {
entries_.emplace_back(std::move(diag));
if (diag.severity >= Severity::Error) { if (diag.severity >= Severity::Error) {
error_count_++; error_count_++;
} }
entries_.emplace_back(std::move(diag));
}
/// adds a list of diagnostics to the end of this list.
/// @param list the diagnostic to append to this list.
void add(const List& list) {
for (auto diag : list) {
add(std::move(diag));
}
} }
/// @returns true iff the diagnostic list contains errors diagnostics (or of /// @returns true iff the diagnostic list contains errors diagnostics (or of

View File

@ -0,0 +1,149 @@
// Copyright 2020 The Tint Authors.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include "src/transform/bound_array_accessors.h"
#include <memory>
#include <utility>
#include "src/ast/assignment_statement.h"
#include "src/ast/binary_expression.h"
#include "src/ast/bitcast_expression.h"
#include "src/ast/block_statement.h"
#include "src/ast/break_statement.h"
#include "src/ast/call_expression.h"
#include "src/ast/call_statement.h"
#include "src/ast/case_statement.h"
#include "src/ast/clone_context.h"
#include "src/ast/continue_statement.h"
#include "src/ast/discard_statement.h"
#include "src/ast/else_statement.h"
#include "src/ast/fallthrough_statement.h"
#include "src/ast/if_statement.h"
#include "src/ast/loop_statement.h"
#include "src/ast/member_accessor_expression.h"
#include "src/ast/return_statement.h"
#include "src/ast/scalar_constructor_expression.h"
#include "src/ast/sint_literal.h"
#include "src/ast/switch_statement.h"
#include "src/ast/type/array_type.h"
#include "src/ast/type/matrix_type.h"
#include "src/ast/type/u32_type.h"
#include "src/ast/type/vector_type.h"
#include "src/ast/type_constructor_expression.h"
#include "src/ast/uint_literal.h"
#include "src/ast/unary_op_expression.h"
#include "src/ast/variable.h"
#include "src/ast/variable_decl_statement.h"
namespace tint {
namespace transform {
BoundArrayAccessors::BoundArrayAccessors() = default;
BoundArrayAccessors::~BoundArrayAccessors() = default;
Transform::Output BoundArrayAccessors::Run(ast::Module* mod) {
Output out;
ast::CloneContext ctx(&out.module);
ctx.ReplaceAll([&](ast::ArrayAccessorExpression* expr) {
return Transform(expr, &ctx, &out.diagnostics);
});
mod->Clone(&ctx);
return out;
}
ast::ArrayAccessorExpression* BoundArrayAccessors::Transform(
ast::ArrayAccessorExpression* expr,
ast::CloneContext* ctx,
diag::List* diags) {
auto* ret_type = expr->array()->result_type()->UnwrapAll();
if (!ret_type->Is<ast::type::Array>() && !ret_type->Is<ast::type::Matrix>() &&
!ret_type->Is<ast::type::Vector>()) {
return nullptr;
}
uint32_t size = 0;
if (ret_type->Is<ast::type::Vector>() || ret_type->Is<ast::type::Array>()) {
size = ret_type->Is<ast::type::Vector>()
? ret_type->As<ast::type::Vector>()->size()
: ret_type->As<ast::type::Array>()->size();
if (size == 0) {
diag::Diagnostic err;
err.severity = diag::Severity::Error;
err.message = "invalid 0 size for array or vector";
err.source = expr->source();
diags->add(std::move(err));
return nullptr;
}
} else {
// The row accessor would have been an embedded array accessor and already
// handled, so we just need to do columns here.
size = ret_type->As<ast::type::Matrix>()->columns();
}
ast::Expression* idx_expr = nullptr;
// Scalar constructor we can re-write the value to be within bounds.
if (auto* c = expr->idx_expr()->As<ast::ScalarConstructorExpression>()) {
auto* lit = c->literal();
if (auto* sint = lit->As<ast::SintLiteral>()) {
int32_t val = sint->value();
if (val < 0) {
val = 0;
} else if (val >= int32_t(size)) {
val = int32_t(size) - 1;
}
lit = ctx->mod->create<ast::SintLiteral>(ctx->Clone(sint->type()), val);
} else if (auto* uint = lit->As<ast::UintLiteral>()) {
uint32_t val = uint->value();
if (val >= size - 1) {
val = size - 1;
}
lit = ctx->mod->create<ast::UintLiteral>(ctx->Clone(uint->type()), val);
} else {
diag::Diagnostic err;
err.severity = diag::Severity::Error;
err.message = "unknown scalar constructor type for accessor";
err.source = expr->source();
diags->add(std::move(err));
return nullptr;
}
idx_expr =
ctx->mod->create<ast::ScalarConstructorExpression>(c->source(), lit);
} else {
auto* u32 = ctx->mod->create<ast::type::U32>();
ast::ExpressionList cast_expr;
cast_expr.push_back(ctx->Clone(expr->idx_expr()));
ast::ExpressionList params;
params.push_back(
ctx->mod->create<ast::TypeConstructorExpression>(u32, cast_expr));
params.push_back(ctx->mod->create<ast::ScalarConstructorExpression>(
ctx->mod->create<ast::UintLiteral>(u32, size - 1)));
auto* call_expr = ctx->mod->create<ast::CallExpression>(
ctx->mod->create<ast::IdentifierExpression>("min"), std::move(params));
call_expr->set_result_type(u32);
idx_expr = call_expr;
}
auto* arr = ctx->Clone(expr->array());
return ctx->mod->create<ast::ArrayAccessorExpression>(arr, idx_expr);
}
} // namespace transform
} // namespace tint

View File

@ -0,0 +1,59 @@
// Copyright 2020 The Tint Authors.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#ifndef SRC_TRANSFORM_BOUND_ARRAY_ACCESSORS_H_
#define SRC_TRANSFORM_BOUND_ARRAY_ACCESSORS_H_
#include <string>
#include "src/ast/array_accessor_expression.h"
#include "src/ast/expression.h"
#include "src/ast/module.h"
#include "src/ast/statement.h"
#include "src/context.h"
#include "src/scope_stack.h"
#include "src/transform/transform.h"
namespace tint {
namespace transform {
/// This transform is responsible for clamping all array accesses to be within
/// the bounds of the array. Any access before the start of the array will clamp
/// to zero and any access past the end of the array will clamp to
/// (array length - 1).
class BoundArrayAccessors : public Transform {
public:
/// Constructor
BoundArrayAccessors();
/// Destructor
~BoundArrayAccessors() override;
/// Runs the transform on `module`, returning the transformation result.
/// @note Users of Tint should register the transform with transform manager
/// and invoke its Run(), instead of directly calling the transform's Run().
/// Calling Run() directly does not perform module state cleanup operations.
/// @param module the source module to transform
/// @returns the transformation result
Output Run(ast::Module* module) override;
private:
ast::ArrayAccessorExpression* Transform(ast::ArrayAccessorExpression* expr,
ast::CloneContext* ctx,
diag::List* diags);
};
} // namespace transform
} // namespace tint
#endif // SRC_TRANSFORM_BOUND_ARRAY_ACCESSORS_H_

View File

@ -12,7 +12,7 @@
// See the License for the specific language governing permissions and // See the License for the specific language governing permissions and
// limitations under the License. // limitations under the License.
#include "src/transform/bound_array_accessors_transform.h" #include "src/transform/bound_array_accessors.h"
#include <memory> #include <memory>
#include <utility> #include <utility>
@ -67,21 +67,24 @@ T* FindVariable(ast::Module* mod, std::string name) {
class BoundArrayAccessorsTest : public testing::Test { class BoundArrayAccessorsTest : public testing::Test {
public: public:
ast::Module Transform(ast::Module mod) { ast::Module Transform(ast::Module in) {
TypeDeterminer td(&mod); TypeDeterminer td(&in);
if (!td.Determine()) { if (!td.Determine()) {
error = "Type determination failed: " + td.error(); error = "Type determination failed: " + td.error();
return {}; return {};
} }
Manager manager; Manager manager;
manager.append(std::make_unique<BoundArrayAccessorsTransform>(&mod)); manager.append(std::make_unique<BoundArrayAccessors>());
if (!manager.Run(&mod)) { auto result = manager.Run(&in);
error = "manager().Run() errored:\n" + manager.error();
if (result.diagnostics.contains_errors()) {
error = "manager().Run() errored:\n" +
diag::Formatter().format(result.diagnostics);
return {}; return {};
} }
return mod; return std::move(result.module);
} }
std::string error; std::string error;

View File

@ -1,258 +0,0 @@
// Copyright 2020 The Tint Authors.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include "src/transform/bound_array_accessors_transform.h"
#include <memory>
#include <utility>
#include "src/ast/assignment_statement.h"
#include "src/ast/binary_expression.h"
#include "src/ast/bitcast_expression.h"
#include "src/ast/block_statement.h"
#include "src/ast/break_statement.h"
#include "src/ast/call_expression.h"
#include "src/ast/call_statement.h"
#include "src/ast/case_statement.h"
#include "src/ast/continue_statement.h"
#include "src/ast/discard_statement.h"
#include "src/ast/else_statement.h"
#include "src/ast/fallthrough_statement.h"
#include "src/ast/if_statement.h"
#include "src/ast/loop_statement.h"
#include "src/ast/member_accessor_expression.h"
#include "src/ast/return_statement.h"
#include "src/ast/scalar_constructor_expression.h"
#include "src/ast/sint_literal.h"
#include "src/ast/switch_statement.h"
#include "src/ast/type/array_type.h"
#include "src/ast/type/matrix_type.h"
#include "src/ast/type/u32_type.h"
#include "src/ast/type/vector_type.h"
#include "src/ast/type_constructor_expression.h"
#include "src/ast/uint_literal.h"
#include "src/ast/unary_op_expression.h"
#include "src/ast/variable.h"
#include "src/ast/variable_decl_statement.h"
namespace tint {
namespace transform {
BoundArrayAccessorsTransform::BoundArrayAccessorsTransform(ast::Module* mod)
: Transformer(mod) {}
BoundArrayAccessorsTransform::BoundArrayAccessorsTransform(Context*,
ast::Module* mod)
: BoundArrayAccessorsTransform(mod) {}
BoundArrayAccessorsTransform::~BoundArrayAccessorsTransform() = default;
bool BoundArrayAccessorsTransform::Run() {
// We skip over global variables as the constructor for a global must be a
// constant expression. There can't be any array accessors as per the current
// grammar.
for (auto* func : mod_->functions()) {
scope_stack_.push_scope();
if (!ProcessStatement(func->body())) {
return false;
}
scope_stack_.pop_scope();
}
return true;
}
bool BoundArrayAccessorsTransform::ProcessStatement(ast::Statement* stmt) {
if (auto* as = stmt->As<ast::AssignmentStatement>()) {
return ProcessExpression(as->lhs()) && ProcessExpression(as->rhs());
} else if (auto* block = stmt->As<ast::BlockStatement>()) {
for (auto* s : *block) {
if (!ProcessStatement(s)) {
return false;
}
}
} else if (stmt->Is<ast::BreakStatement>()) {
/* nop */
} else if (auto* call = stmt->As<ast::CallStatement>()) {
return ProcessExpression(call->expr());
} else if (auto* kase = stmt->As<ast::CaseStatement>()) {
return ProcessStatement(kase->body());
} else if (stmt->Is<ast::ContinueStatement>()) {
/* nop */
} else if (stmt->Is<ast::DiscardStatement>()) {
/* nop */
} else if (auto* e = stmt->As<ast::ElseStatement>()) {
return ProcessExpression(e->condition()) && ProcessStatement(e->body());
} else if (stmt->Is<ast::FallthroughStatement>()) {
/* nop */
} else if (auto* i = stmt->As<ast::IfStatement>()) {
if (!ProcessExpression(i->condition()) || !ProcessStatement(i->body())) {
return false;
}
for (auto* s : i->else_statements()) {
if (!ProcessStatement(s)) {
return false;
}
}
} else if (auto* l = stmt->As<ast::LoopStatement>()) {
if (l->has_continuing() && !ProcessStatement(l->continuing())) {
return false;
}
return ProcessStatement(l->body());
} else if (auto* r = stmt->As<ast::ReturnStatement>()) {
if (r->has_value()) {
return ProcessExpression(r->value());
}
} else if (auto* s = stmt->As<ast::SwitchStatement>()) {
if (!ProcessExpression(s->condition())) {
return false;
}
for (auto* c : s->body()) {
if (!ProcessStatement(c)) {
return false;
}
}
} else if (auto* vd = stmt->As<ast::VariableDeclStatement>()) {
auto* v = vd->variable();
if (v->has_constructor() && !ProcessExpression(v->constructor())) {
return false;
}
scope_stack_.set(v->name(), v);
} else {
error_ = "unknown statement in bound array accessors transform";
return false;
}
return true;
}
bool BoundArrayAccessorsTransform::ProcessExpression(ast::Expression* expr) {
if (auto* array = expr->As<ast::ArrayAccessorExpression>()) {
return ProcessArrayAccessor(array);
} else if (auto* bitcast = expr->As<ast::BitcastExpression>()) {
return ProcessExpression(bitcast->expr());
} else if (auto* call = expr->As<ast::CallExpression>()) {
if (!ProcessExpression(call->func())) {
return false;
}
for (auto* e : call->params()) {
if (!ProcessExpression(e)) {
return false;
}
}
} else if (expr->Is<ast::IdentifierExpression>()) {
/* nop */
} else if (expr->Is<ast::ConstructorExpression>()) {
if (auto* c = expr->As<ast::TypeConstructorExpression>()) {
for (auto* e : c->values()) {
if (!ProcessExpression(e)) {
return false;
}
}
}
} else if (auto* m = expr->As<ast::MemberAccessorExpression>()) {
return ProcessExpression(m->structure()) && ProcessExpression(m->member());
} else if (auto* b = expr->As<ast::BinaryExpression>()) {
return ProcessExpression(b->lhs()) && ProcessExpression(b->rhs());
} else if (auto* u = expr->As<ast::UnaryOpExpression>()) {
return ProcessExpression(u->expr());
} else {
error_ = "unknown statement in bound array accessors transform";
return false;
}
return true;
}
bool BoundArrayAccessorsTransform::ProcessArrayAccessor(
ast::ArrayAccessorExpression* expr) {
if (!ProcessExpression(expr->array()) ||
!ProcessExpression(expr->idx_expr())) {
return false;
}
auto* ret_type = expr->array()->result_type()->UnwrapAll();
if (!ret_type->Is<ast::type::Array>() && !ret_type->Is<ast::type::Matrix>() &&
!ret_type->Is<ast::type::Vector>()) {
return true;
}
if (ret_type->Is<ast::type::Vector>() || ret_type->Is<ast::type::Array>()) {
uint32_t size = ret_type->Is<ast::type::Vector>()
? ret_type->As<ast::type::Vector>()->size()
: ret_type->As<ast::type::Array>()->size();
if (size == 0) {
error_ = "invalid 0 size for array or vector";
return false;
}
if (!ProcessAccessExpression(expr, size)) {
return false;
}
} else {
// The row accessor would have been an embedded array accessor and already
// handled, so we just need to do columns here.
uint32_t size = ret_type->As<ast::type::Matrix>()->columns();
if (!ProcessAccessExpression(expr, size)) {
return false;
}
}
return true;
}
bool BoundArrayAccessorsTransform::ProcessAccessExpression(
ast::ArrayAccessorExpression* expr,
uint32_t size) {
// Scalar constructor we can re-write the value to be within bounds.
if (auto* c = expr->idx_expr()->As<ast::ScalarConstructorExpression>()) {
auto* lit = c->literal();
if (auto* sint = lit->As<ast::SintLiteral>()) {
int32_t val = sint->value();
if (val < 0) {
val = 0;
} else if (val >= int32_t(size)) {
val = int32_t(size) - 1;
}
sint->set_value(val);
} else if (auto* uint = lit->As<ast::UintLiteral>()) {
uint32_t val = uint->value();
if (val >= size - 1) {
val = size - 1;
}
uint->set_value(val);
} else {
error_ = "unknown scalar constructor type for accessor";
return false;
}
} else {
auto* u32 = mod_->create<ast::type::U32>();
ast::ExpressionList cast_expr;
cast_expr.push_back(expr->idx_expr());
ast::ExpressionList params;
params.push_back(create<ast::TypeConstructorExpression>(u32, cast_expr));
params.push_back(create<ast::ScalarConstructorExpression>(
create<ast::UintLiteral>(u32, size - 1)));
auto* call_expr = create<ast::CallExpression>(
create<ast::IdentifierExpression>("min"), std::move(params));
call_expr->set_result_type(u32);
expr->set_idx_expr(call_expr);
}
return true;
}
} // namespace transform
} // namespace tint

View File

@ -1,66 +0,0 @@
// Copyright 2020 The Tint Authors.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#ifndef SRC_TRANSFORM_BOUND_ARRAY_ACCESSORS_TRANSFORM_H_
#define SRC_TRANSFORM_BOUND_ARRAY_ACCESSORS_TRANSFORM_H_
#include <string>
#include "src/ast/array_accessor_expression.h"
#include "src/ast/expression.h"
#include "src/ast/module.h"
#include "src/ast/statement.h"
#include "src/context.h"
#include "src/scope_stack.h"
#include "src/transform/transformer.h"
namespace tint {
namespace transform {
/// This transformer is responsible for clamping all array accesses to be within
/// the bounds of the array. Any access before the start of the array will clamp
/// to zero and any access past the end of the array will clamp to
/// (array length - 1).
class BoundArrayAccessorsTransform : public Transformer {
public:
/// Constructor
/// @param mod the module transform
explicit BoundArrayAccessorsTransform(ast::Module* mod);
/// Constructor
/// DEPRECATED
/// @param ctx the Tint context object
/// @param mod the module transform
BoundArrayAccessorsTransform(Context* ctx, ast::Module* mod);
~BoundArrayAccessorsTransform() override;
/// Users of Tint should register the transform with transform manager and
/// invoke its Run(), instead of directly calling the transform's Run().
/// Calling Run() directly does not perform module state cleanup operations.
/// @returns true if the transformation was successful
bool Run() override;
private:
bool ProcessStatement(ast::Statement* stmt);
bool ProcessExpression(ast::Expression* expr);
bool ProcessArrayAccessor(ast::ArrayAccessorExpression* expr);
bool ProcessAccessExpression(ast::ArrayAccessorExpression* expr,
uint32_t size);
ScopeStack<ast::Variable*> scope_stack_;
};
} // namespace transform
} // namespace tint
#endif // SRC_TRANSFORM_BOUND_ARRAY_ACCESSORS_TRANSFORM_H_

View File

@ -12,7 +12,7 @@
// See the License for the specific language governing permissions and // See the License for the specific language governing permissions and
// limitations under the License. // limitations under the License.
#include "src/transform/emit_vertex_point_size_transform.h" #include "src/transform/emit_vertex_point_size.h"
#include <memory> #include <memory>
#include <utility> #include <utility>
@ -34,44 +34,46 @@ const char kPointSizeVar[] = "tint_pointsize";
} // namespace } // namespace
EmitVertexPointSizeTransform::EmitVertexPointSizeTransform(ast::Module* mod) EmitVertexPointSize::EmitVertexPointSize() = default;
: Transformer(mod) {} EmitVertexPointSize::~EmitVertexPointSize() = default;
EmitVertexPointSizeTransform::~EmitVertexPointSizeTransform() = default; Transform::Output EmitVertexPointSize::Run(ast::Module* in) {
Output out;
out.module = in->Clone();
auto* mod = &out.module;
bool EmitVertexPointSizeTransform::Run() { if (!mod->HasStage(ast::PipelineStage::kVertex)) {
if (!mod_->HasStage(ast::PipelineStage::kVertex)) {
// If the module doesn't have any vertex stages, then there's nothing to do. // If the module doesn't have any vertex stages, then there's nothing to do.
return true; return out;
} }
auto* f32 = mod_->create<ast::type::F32>(); auto* f32 = mod->create<ast::type::F32>();
// Declare the pointsize builtin output variable. // Declare the pointsize builtin output variable.
auto* pointsize_var = auto* pointsize_var =
mod_->create<ast::DecoratedVariable>(mod_->create<ast::Variable>( mod->create<ast::DecoratedVariable>(mod->create<ast::Variable>(
kPointSizeVar, ast::StorageClass::kOutput, f32)); kPointSizeVar, ast::StorageClass::kOutput, f32));
pointsize_var->set_decorations({ pointsize_var->set_decorations({
mod_->create<ast::BuiltinDecoration>(ast::Builtin::kPointSize, Source{}), mod->create<ast::BuiltinDecoration>(ast::Builtin::kPointSize, Source{}),
}); });
mod_->AddGlobalVariable(pointsize_var); mod->AddGlobalVariable(pointsize_var);
// Build the AST expression & statement for assigning pointsize one. // Build the AST expression & statement for assigning pointsize one.
auto* one = mod_->create<ast::ScalarConstructorExpression>( auto* one = mod->create<ast::ScalarConstructorExpression>(
mod_->create<ast::FloatLiteral>(f32, 1.0f)); mod->create<ast::FloatLiteral>(f32, 1.0f));
auto* pointsize_ident = auto* pointsize_ident =
mod_->create<ast::IdentifierExpression>(Source{}, kPointSizeVar); mod->create<ast::IdentifierExpression>(Source{}, kPointSizeVar);
auto* pointsize_assign = auto* pointsize_assign =
mod_->create<ast::AssignmentStatement>(pointsize_ident, one); mod->create<ast::AssignmentStatement>(pointsize_ident, one);
// Add the pointsize assignment statement to the front of all vertex stages. // Add the pointsize assignment statement to the front of all vertex stages.
for (auto* func : mod_->functions()) { for (auto* func : mod->functions()) {
if (func->pipeline_stage() == ast::PipelineStage::kVertex) { if (func->pipeline_stage() == ast::PipelineStage::kVertex) {
func->body()->insert(0, pointsize_assign); func->body()->insert(0, pointsize_assign);
} }
} }
return true; return out;
} }
} // namespace transform } // namespace transform

View File

@ -12,34 +12,36 @@
// See the License for the specific language governing permissions and // See the License for the specific language governing permissions and
// limitations under the License. // limitations under the License.
#ifndef SRC_TRANSFORM_EMIT_VERTEX_POINT_SIZE_TRANSFORM_H_ #ifndef SRC_TRANSFORM_EMIT_VERTEX_POINT_SIZE_H_
#define SRC_TRANSFORM_EMIT_VERTEX_POINT_SIZE_TRANSFORM_H_ #define SRC_TRANSFORM_EMIT_VERTEX_POINT_SIZE_H_
#include "src/transform/transformer.h" #include "src/transform/transform.h"
namespace tint { namespace tint {
namespace transform { namespace transform {
/// EmitVertexPointSizeTransform is a Transformer that adds a PointSize builtin /// EmitVertexPointSize is a Transform that adds a PointSize builtin global
/// global output variable to the module which is assigned 1.0 as the new first /// output variable to the module which is assigned 1.0 as the new first
/// statement for all vertex stage entry points. /// statement for all vertex stage entry points.
/// If the module does not contain a vertex pipeline stage entry point then then /// If the module does not contain a vertex pipeline stage entry point then then
/// this transformer is a no-op. /// this transform is a no-op.
class EmitVertexPointSizeTransform : public Transformer { class EmitVertexPointSize : public Transform {
public: public:
/// Constructor /// Constructor
/// @param mod the module transform EmitVertexPointSize();
explicit EmitVertexPointSizeTransform(ast::Module* mod); /// Destructor
~EmitVertexPointSizeTransform() override; ~EmitVertexPointSize() override;
/// Users of Tint should register the transform with transform manager and /// Runs the transform on `module`, returning the transformation result.
/// invoke its Run(), instead of directly calling the transform's Run(). /// @note Users of Tint should register the transform with transform manager
/// and invoke its Run(), instead of directly calling the transform's Run().
/// Calling Run() directly does not perform module state cleanup operations. /// Calling Run() directly does not perform module state cleanup operations.
/// @returns true if the transformation was successful /// @param module the source module to transform
bool Run() override; /// @returns the transformation result
Output Run(ast::Module* module) override;
}; };
} // namespace transform } // namespace transform
} // namespace tint } // namespace tint
#endif // SRC_TRANSFORM_EMIT_VERTEX_POINT_SIZE_TRANSFORM_H_ #endif // SRC_TRANSFORM_EMIT_VERTEX_POINT_SIZE_H_

View File

@ -12,7 +12,7 @@
// See the License for the specific language governing permissions and // See the License for the specific language governing permissions and
// limitations under the License. // limitations under the License.
#include "src/transform/emit_vertex_point_size_transform.h" #include "src/transform/emit_vertex_point_size.h"
#include <memory> #include <memory>
#include <utility> #include <utility>
@ -21,7 +21,6 @@
#include "src/ast/builder.h" #include "src/ast/builder.h"
#include "src/ast/stage_decoration.h" #include "src/ast/stage_decoration.h"
#include "src/ast/variable_decl_statement.h" #include "src/ast/variable_decl_statement.h"
#include "src/diagnostic/diagnostic.h"
#include "src/diagnostic/formatter.h" #include "src/diagnostic/formatter.h"
#include "src/transform/manager.h" #include "src/transform/manager.h"
@ -29,26 +28,12 @@ namespace tint {
namespace transform { namespace transform {
namespace { namespace {
class EmitVertexPointSizeTransformTest : public testing::Test { class EmitVertexPointSizeTest : public testing::Test {
public: public:
struct Output { Transform::Output Transform(ast::Module in) {
ast::Module module;
diag::List diagnostics;
};
Output Transform(ast::Module mod) {
Manager manager; Manager manager;
manager.append(std::make_unique<EmitVertexPointSizeTransform>(&mod)); manager.append(std::make_unique<EmitVertexPointSize>());
manager.Run(&mod); return manager.Run(&in);
Output out;
out.module = std::move(mod);
auto err = manager.error();
if (!err.empty()) {
diag::Diagnostic diag;
diag.message = err;
diag.severity = diag::Severity::Error;
out.diagnostics.add(std::move(diag));
}
return out;
} }
}; };
@ -64,7 +49,7 @@ struct ModuleBuilder : public ast::BuilderWithModule {
virtual void Build() = 0; virtual void Build() = 0;
}; };
TEST_F(EmitVertexPointSizeTransformTest, VertexStageBasic) { TEST_F(EmitVertexPointSizeTest, VertexStageBasic) {
struct Builder : ModuleBuilder { struct Builder : ModuleBuilder {
void Build() override { void Build() override {
auto* block = create<ast::BlockStatement>(Source{}); auto* block = create<ast::BlockStatement>(Source{});
@ -131,7 +116,7 @@ TEST_F(EmitVertexPointSizeTransformTest, VertexStageBasic) {
EXPECT_EQ(expected, result.module.to_str()); EXPECT_EQ(expected, result.module.to_str());
} }
TEST_F(EmitVertexPointSizeTransformTest, VertexStageEmpty) { TEST_F(EmitVertexPointSizeTest, VertexStageEmpty) {
struct Builder : ModuleBuilder { struct Builder : ModuleBuilder {
void Build() override { void Build() override {
mod->AddFunction( mod->AddFunction(
@ -186,7 +171,7 @@ TEST_F(EmitVertexPointSizeTransformTest, VertexStageEmpty) {
EXPECT_EQ(expected, result.module.to_str()); EXPECT_EQ(expected, result.module.to_str());
} }
TEST_F(EmitVertexPointSizeTransformTest, NonVertexStage) { TEST_F(EmitVertexPointSizeTest, NonVertexStage) {
struct Builder : ModuleBuilder { struct Builder : ModuleBuilder {
void Build() override { void Build() override {
auto* fragment_entry = auto* fragment_entry =

View File

@ -20,36 +20,29 @@ namespace tint {
namespace transform { namespace transform {
Manager::Manager() = default; Manager::Manager() = default;
Manager::Manager(Context*, ast::Module* module) : module_(module) {}
Manager::~Manager() = default; Manager::~Manager() = default;
bool Manager::Run() { Transform::Output Manager::Run(ast::Module* module) {
return Run(module_); Output out;
}
bool Manager::Run(ast::Module* module) {
error_ = "";
for (auto& transform : transforms_) { for (auto& transform : transforms_) {
if (!transform->Run()) { auto res = transform->Run(module);
error_ = transform->error(); out.module = std::move(res.module);
return false; out.diagnostics.add(std::move(res.diagnostics));
if (out.diagnostics.contains_errors()) {
return out;
} }
module = &out.module;
} }
if (module != nullptr) { TypeDeterminer td(module);
// The transformed have potentially inserted nodes into the AST, so the type if (!td.Determine()) {
// determinater needs to be run. diag::Diagnostic err;
TypeDeterminer td(module); err.severity = diag::Severity::Error;
if (!td.Determine()) { err.message = td.error();
error_ = td.error(); out.diagnostics.add(std::move(err));
return false;
}
} }
return true; return out;
} }
} // namespace transform } // namespace transform

View File

@ -20,8 +20,13 @@
#include <utility> #include <utility>
#include <vector> #include <vector>
#include "src/context.h" #include "src/diagnostic/diagnostic.h"
#include "src/transform/transformer.h" #include "src/transform/transform.h"
// A define used by Dawn to atomically switch to the new tint::transform API
// when the API breaking change lands.
// TODO(bclayton) - Remove once migration is complete
#define DAWN_USE_NEW_TINT_TRANSFORM_API 1
namespace tint { namespace tint {
namespace transform { namespace transform {
@ -29,40 +34,25 @@ namespace transform {
/// Manager for the provided passes. The passes will be execute in the /// Manager for the provided passes. The passes will be execute in the
/// appended order. If any pass fails the manager will return immediately and /// appended order. If any pass fails the manager will return immediately and
/// the error can be retrieved with the error() method. /// the error can be retrieved with the error() method.
class Manager { class Manager : public Transform {
public: public:
/// Constructor /// Constructor
Manager(); Manager();
/// Constructor ~Manager() override;
/// DEPRECATED
/// @param context the tint context
/// @param module the module to transform
Manager(Context* context, ast::Module* module);
~Manager();
/// Add pass to the manager /// Add pass to the manager
/// @param transform the transform to append /// @param transform the transform to append
void append(std::unique_ptr<Transformer> transform) { void append(std::unique_ptr<Transform> transform) {
transforms_.push_back(std::move(transform)); transforms_.push_back(std::move(transform));
} }
/// Runs the transforms /// Runs the transforms on `module`, returning the transformation result.
/// @param module the module to run the transforms on /// @param module the source module to transform
/// @returns true on success; false otherwise /// @returns the transformed module and diagnostics
bool Run(ast::Module* module); Output Run(ast::Module* module) override;
/// Runs the transforms
/// DEPRECATED
/// @returns true on success; false otherwise
bool Run();
/// @returns the error, or blank if none set
std::string error() const { return error_; }
private: private:
std::vector<std::unique_ptr<Transformer>> transforms_; std::vector<std::unique_ptr<Transform>> transforms_;
ast::Module* module_ = nullptr;
std::string error_;
}; };
} // namespace transform } // namespace transform

View File

@ -12,14 +12,13 @@
// See the License for the specific language governing permissions and // See the License for the specific language governing permissions and
// limitations under the License. // limitations under the License.
#include "src/transform/transformer.h" #include "src/transform/transform.h"
namespace tint { namespace tint {
namespace transform { namespace transform {
Transformer::Transformer(ast::Module* mod) : mod_(mod) {} Transform::Transform() = default;
Transform::~Transform() = default;
Transformer::~Transformer() = default;
} // namespace transform } // namespace transform
} // namespace tint } // namespace tint

57
src/transform/transform.h Normal file
View File

@ -0,0 +1,57 @@
// Copyright 2020 The Tint Authors.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#ifndef SRC_TRANSFORM_TRANSFORM_H_
#define SRC_TRANSFORM_TRANSFORM_H_
#include <memory>
#include <string>
#include <utility>
#include "src/ast/module.h"
#include "src/context.h"
#include "src/diagnostic/diagnostic.h"
namespace tint {
namespace transform {
/// Interface for ast::Module transforms
class Transform {
public:
/// Constructor
Transform();
/// Destructor
virtual ~Transform();
/// The return type of Run()
struct Output {
/// The transformed module. May be empty on error.
ast::Module module;
/// Diagnostics raised while running the Transform.
diag::List diagnostics;
};
/// Runs the transform on `module`, returning the transformation result.
/// @note Users of Tint should register the transform with transform manager
/// and invoke its Run(), instead of directly calling the transform's Run().
/// Calling Run() directly does not perform module state cleanup operations.
/// @param module the source module to transform
/// @returns the transformation result
virtual Output Run(ast::Module* module) = 0;
};
} // namespace transform
} // namespace tint
#endif // SRC_TRANSFORM_TRANSFORM_H_

View File

@ -1,64 +0,0 @@
// Copyright 2020 The Tint Authors.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#ifndef SRC_TRANSFORM_TRANSFORMER_H_
#define SRC_TRANSFORM_TRANSFORMER_H_
#include <memory>
#include <string>
#include <utility>
#include "src/ast/module.h"
#include "src/context.h"
namespace tint {
namespace transform {
/// Interface class for the transformers
class Transformer {
public:
/// Constructor
/// @param mod the module to transform
explicit Transformer(ast::Module* mod);
virtual ~Transformer();
/// Users of Tint should register the transform with transform manager and
/// invoke its Run(), instead of directly calling the transform's Run().
/// Calling Run() directly does not perform module state cleanup operations.
/// @returns true if the transformation was successful
virtual bool Run() = 0;
/// @returns error messages
const std::string& error() { return error_; }
protected:
/// Creates a new `ast::Node` owned by the Module. When the Module is
/// destructed, the `ast::Node` will also be destructed.
/// @param args the arguments to pass to the type constructor
/// @returns the node pointer
template <typename T, typename... ARGS>
T* create(ARGS&&... args) {
return mod_->create<T>(std::forward<ARGS>(args)...);
}
/// The module
ast::Module* mod_ = nullptr;
/// Any error messages, or blank if no error
std::string error_;
};
} // namespace transform
} // namespace tint
#endif // SRC_TRANSFORM_TRANSFORMER_H_

View File

@ -0,0 +1,473 @@
// Copyright 2020 The Tint Authors.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include "src/transform/vertex_pulling.h"
#include <utility>
#include "src/ast/array_accessor_expression.h"
#include "src/ast/assignment_statement.h"
#include "src/ast/binary_expression.h"
#include "src/ast/bitcast_expression.h"
#include "src/ast/decorated_variable.h"
#include "src/ast/member_accessor_expression.h"
#include "src/ast/scalar_constructor_expression.h"
#include "src/ast/stride_decoration.h"
#include "src/ast/struct.h"
#include "src/ast/struct_block_decoration.h"
#include "src/ast/struct_decoration.h"
#include "src/ast/struct_member.h"
#include "src/ast/struct_member_offset_decoration.h"
#include "src/ast/type/array_type.h"
#include "src/ast/type/f32_type.h"
#include "src/ast/type/i32_type.h"
#include "src/ast/type/struct_type.h"
#include "src/ast/type/u32_type.h"
#include "src/ast/type/vector_type.h"
#include "src/ast/type_constructor_expression.h"
#include "src/ast/uint_literal.h"
#include "src/ast/variable_decl_statement.h"
namespace tint {
namespace transform {
namespace {
static const char kVertexBufferNamePrefix[] = "_tint_pulling_vertex_buffer_";
static const char kStructBufferName[] = "_tint_vertex_data";
static const char kStructName[] = "TintVertexData";
static const char kPullingPosVarName[] = "_tint_pulling_pos";
static const char kDefaultVertexIndexName[] = "_tint_pulling_vertex_index";
static const char kDefaultInstanceIndexName[] = "_tint_pulling_instance_index";
} // namespace
VertexPulling::VertexPulling() = default;
VertexPulling::~VertexPulling() = default;
void VertexPulling::SetVertexState(const VertexStateDescriptor& vertex_state) {
cfg.vertex_state = vertex_state;
cfg.vertex_state_set = true;
}
void VertexPulling::SetEntryPoint(std::string entry_point) {
cfg.entry_point_name = std::move(entry_point);
}
void VertexPulling::SetPullingBufferBindingSet(uint32_t number) {
cfg.pulling_set = number;
}
Transform::Output VertexPulling::Run(ast::Module* in) {
Output out;
out.module = in->Clone();
ast::Module* mod = &out.module;
// Check SetVertexState was called
if (!cfg.vertex_state_set) {
diag::Diagnostic err;
err.severity = diag::Severity::Error;
err.message = "SetVertexState not called";
out.diagnostics.add(std::move(err));
return out;
}
// Find entry point
auto* func = mod->FindFunctionByNameAndStage(cfg.entry_point_name,
ast::PipelineStage::kVertex);
if (func == nullptr) {
diag::Diagnostic err;
err.severity = diag::Severity::Error;
err.message = "Vertex stage entry point not found";
out.diagnostics.add(std::move(err));
return out;
}
// Save the vertex function
auto* vertex_func = mod->FindFunctionByName(func->name());
// TODO(idanr): Need to check shader locations in descriptor cover all
// attributes
// TODO(idanr): Make sure we covered all error cases, to guarantee the
// following stages will pass
State state{mod, cfg};
state.FindOrInsertVertexIndexIfUsed();
state.FindOrInsertInstanceIndexIfUsed();
state.ConvertVertexInputVariablesToPrivate();
state.AddVertexStorageBuffers();
state.AddVertexPullingPreamble(vertex_func);
return out;
}
VertexPulling::Config::Config() = default;
VertexPulling::Config::Config(const Config&) = default;
VertexPulling::Config::~Config() = default;
VertexPulling::State::State(ast::Module* m, const Config& c) : mod(m), cfg(c) {}
VertexPulling::State::~State() = default;
std::string VertexPulling::State::GetVertexBufferName(uint32_t index) {
return kVertexBufferNamePrefix + std::to_string(index);
}
void VertexPulling::State::FindOrInsertVertexIndexIfUsed() {
bool uses_vertex_step_mode = false;
for (const VertexBufferLayoutDescriptor& buffer_layout : cfg.vertex_state) {
if (buffer_layout.step_mode == InputStepMode::kVertex) {
uses_vertex_step_mode = true;
break;
}
}
if (!uses_vertex_step_mode) {
return;
}
// Look for an existing vertex index builtin
for (auto* v : mod->global_variables()) {
if (v->storage_class() != ast::StorageClass::kInput) {
continue;
}
if (auto* decorated = v->As<ast::DecoratedVariable>()) {
for (auto* d : decorated->decorations()) {
if (auto* builtin = d->As<ast::BuiltinDecoration>()) {
if (builtin->value() == ast::Builtin::kVertexIdx) {
vertex_index_name = v->name();
return;
}
}
}
}
}
// We didn't find a vertex index builtin, so create one
vertex_index_name = kDefaultVertexIndexName;
auto* var = mod->create<ast::DecoratedVariable>(mod->create<ast::Variable>(
vertex_index_name, ast::StorageClass::kInput, GetI32Type()));
ast::VariableDecorationList decorations;
decorations.push_back(
mod->create<ast::BuiltinDecoration>(ast::Builtin::kVertexIdx, Source{}));
var->set_decorations(std::move(decorations));
mod->AddGlobalVariable(var);
}
void VertexPulling::State::FindOrInsertInstanceIndexIfUsed() {
bool uses_instance_step_mode = false;
for (const VertexBufferLayoutDescriptor& buffer_layout : cfg.vertex_state) {
if (buffer_layout.step_mode == InputStepMode::kInstance) {
uses_instance_step_mode = true;
break;
}
}
if (!uses_instance_step_mode) {
return;
}
// Look for an existing instance index builtin
for (auto* v : mod->global_variables()) {
if (v->storage_class() != ast::StorageClass::kInput) {
continue;
}
if (auto* decorated = v->As<ast::DecoratedVariable>()) {
for (auto* d : decorated->decorations()) {
if (auto* builtin = d->As<ast::BuiltinDecoration>()) {
if (builtin->value() == ast::Builtin::kInstanceIdx) {
instance_index_name = v->name();
return;
}
}
}
}
}
// We didn't find an instance index builtin, so create one
instance_index_name = kDefaultInstanceIndexName;
auto* var = mod->create<ast::DecoratedVariable>(mod->create<ast::Variable>(
instance_index_name, ast::StorageClass::kInput, GetI32Type()));
ast::VariableDecorationList decorations;
decorations.push_back(mod->create<ast::BuiltinDecoration>(
ast::Builtin::kInstanceIdx, Source{}));
var->set_decorations(std::move(decorations));
mod->AddGlobalVariable(var);
}
void VertexPulling::State::ConvertVertexInputVariablesToPrivate() {
for (auto*& v : mod->global_variables()) {
if (v->storage_class() != ast::StorageClass::kInput) {
continue;
}
if (auto* decorated = v->As<ast::DecoratedVariable>()) {
for (auto* d : decorated->decorations()) {
if (auto* l = d->As<ast::LocationDecoration>()) {
uint32_t location = l->value();
// This is where the replacement happens. Expressions use identifier
// strings instead of pointers, so we don't need to update any other
// place in the AST.
v = mod->create<ast::Variable>(v->name(), ast::StorageClass::kPrivate,
v->type());
location_to_var[location] = v;
break;
}
}
}
}
}
void VertexPulling::State::AddVertexStorageBuffers() {
// TODO(idanr): Make this readonly https://github.com/gpuweb/gpuweb/issues/935
// The array inside the struct definition
auto internal_array = std::make_unique<ast::type::Array>(GetU32Type());
ast::ArrayDecorationList ary_decos;
ary_decos.push_back(mod->create<ast::StrideDecoration>(4u, Source{}));
internal_array->set_decorations(std::move(ary_decos));
auto* internal_array_type = mod->unique_type(std::move(internal_array));
// Creating the struct type
ast::StructMemberList members;
ast::StructMemberDecorationList member_dec;
member_dec.push_back(
mod->create<ast::StructMemberOffsetDecoration>(0u, Source{}));
members.push_back(mod->create<ast::StructMember>(
kStructBufferName, internal_array_type, std::move(member_dec)));
ast::StructDecorationList decos;
decos.push_back(mod->create<ast::StructBlockDecoration>(Source{}));
auto* struct_type = mod->create<ast::type::Struct>(
kStructName,
mod->create<ast::Struct>(std::move(decos), std::move(members)));
for (uint32_t i = 0; i < cfg.vertex_state.size(); ++i) {
// The decorated variable with struct type
auto* var = mod->create<ast::DecoratedVariable>(mod->create<ast::Variable>(
GetVertexBufferName(i), ast::StorageClass::kStorageBuffer,
struct_type));
// Add decorations
ast::VariableDecorationList decorations;
decorations.push_back(mod->create<ast::BindingDecoration>(i, Source{}));
decorations.push_back(
mod->create<ast::SetDecoration>(cfg.pulling_set, Source{}));
var->set_decorations(std::move(decorations));
mod->AddGlobalVariable(var);
}
mod->AddConstructedType(struct_type);
}
void VertexPulling::State::AddVertexPullingPreamble(
ast::Function* vertex_func) {
// Assign by looking at the vertex descriptor to find attributes with matching
// location.
// A block statement allowing us to use append instead of insert
auto* block = mod->create<ast::BlockStatement>();
// Declare the |kPullingPosVarName| variable in the shader
auto* pos_declaration =
mod->create<ast::VariableDeclStatement>(mod->create<ast::Variable>(
kPullingPosVarName, ast::StorageClass::kFunction, GetI32Type()));
// |kPullingPosVarName| refers to the byte location of the current read. We
// declare a variable in the shader to avoid having to reuse Expression
// objects.
block->append(pos_declaration);
for (uint32_t i = 0; i < cfg.vertex_state.size(); ++i) {
const VertexBufferLayoutDescriptor& buffer_layout = cfg.vertex_state[i];
for (const VertexAttributeDescriptor& attribute_desc :
buffer_layout.attributes) {
auto it = location_to_var.find(attribute_desc.shader_location);
if (it == location_to_var.end()) {
continue;
}
auto* v = it->second;
// Identifier to index by
auto* index_identifier = mod->create<ast::IdentifierExpression>(
buffer_layout.step_mode == InputStepMode::kVertex
? vertex_index_name
: instance_index_name);
// An expression for the start of the read in the buffer in bytes
auto* pos_value = mod->create<ast::BinaryExpression>(
ast::BinaryOp::kAdd,
mod->create<ast::BinaryExpression>(
ast::BinaryOp::kMultiply, index_identifier,
GenUint(static_cast<uint32_t>(buffer_layout.array_stride))),
GenUint(static_cast<uint32_t>(attribute_desc.offset)));
// Update position of the read
auto* set_pos_expr = mod->create<ast::AssignmentStatement>(
CreatePullingPositionIdent(), pos_value);
block->append(set_pos_expr);
block->append(mod->create<ast::AssignmentStatement>(
mod->create<ast::IdentifierExpression>(v->name()),
AccessByFormat(i, attribute_desc.format)));
}
}
vertex_func->body()->insert(0, block);
}
ast::Expression* VertexPulling::State::GenUint(uint32_t value) {
return mod->create<ast::ScalarConstructorExpression>(
mod->create<ast::UintLiteral>(GetU32Type(), value));
}
ast::Expression* VertexPulling::State::CreatePullingPositionIdent() {
return mod->create<ast::IdentifierExpression>(kPullingPosVarName);
}
ast::Expression* VertexPulling::State::AccessByFormat(uint32_t buffer,
VertexFormat format) {
// TODO(idanr): this doesn't account for the format of the attribute in the
// shader. ex: vec<u32> in shader, and attribute claims VertexFormat::Float4
// right now, we would try to assign a vec4<f32> to this attribute, but we
// really need to assign a vec4<u32> by casting.
// We could split this function to first do memory accesses and unpacking into
// int/uint/float1-4/etc, then convert that variable to a var<in> with the
// conversion defined in the WebGPU spec.
switch (format) {
case VertexFormat::kU32:
return AccessU32(buffer, CreatePullingPositionIdent());
case VertexFormat::kI32:
return AccessI32(buffer, CreatePullingPositionIdent());
case VertexFormat::kF32:
return AccessF32(buffer, CreatePullingPositionIdent());
case VertexFormat::kVec2F32:
return AccessVec(buffer, 4, GetF32Type(), VertexFormat::kF32, 2);
case VertexFormat::kVec3F32:
return AccessVec(buffer, 4, GetF32Type(), VertexFormat::kF32, 3);
case VertexFormat::kVec4F32:
return AccessVec(buffer, 4, GetF32Type(), VertexFormat::kF32, 4);
default:
return nullptr;
}
}
ast::Expression* VertexPulling::State::AccessU32(uint32_t buffer,
ast::Expression* pos) {
// Here we divide by 4, since the buffer is uint32 not uint8. The input buffer
// has byte offsets for each attribute, and we will convert it to u32 indexes
// by dividing. Then, that element is going to be read, and if needed,
// unpacked into an appropriate variable. All reads should end up here as a
// base case.
return mod->create<ast::ArrayAccessorExpression>(
mod->create<ast::MemberAccessorExpression>(
mod->create<ast::IdentifierExpression>(GetVertexBufferName(buffer)),
mod->create<ast::IdentifierExpression>(kStructBufferName)),
mod->create<ast::BinaryExpression>(ast::BinaryOp::kDivide, pos,
GenUint(4)));
}
ast::Expression* VertexPulling::State::AccessI32(uint32_t buffer,
ast::Expression* pos) {
// as<T> reinterprets bits
return mod->create<ast::BitcastExpression>(GetI32Type(),
AccessU32(buffer, pos));
}
ast::Expression* VertexPulling::State::AccessF32(uint32_t buffer,
ast::Expression* pos) {
// as<T> reinterprets bits
return mod->create<ast::BitcastExpression>(GetF32Type(),
AccessU32(buffer, pos));
}
ast::Expression* VertexPulling::State::AccessPrimitive(uint32_t buffer,
ast::Expression* pos,
VertexFormat format) {
// This function uses a position expression to read, rather than using the
// position variable. This allows us to read from offset positions relative to
// |kPullingPosVarName|. We can't call AccessByFormat because it reads only
// from the position variable.
switch (format) {
case VertexFormat::kU32:
return AccessU32(buffer, pos);
case VertexFormat::kI32:
return AccessI32(buffer, pos);
case VertexFormat::kF32:
return AccessF32(buffer, pos);
default:
return nullptr;
}
}
ast::Expression* VertexPulling::State::AccessVec(uint32_t buffer,
uint32_t element_stride,
ast::type::Type* base_type,
VertexFormat base_format,
uint32_t count) {
ast::ExpressionList expr_list;
for (uint32_t i = 0; i < count; ++i) {
// Offset read position by element_stride for each component
auto* cur_pos = mod->create<ast::BinaryExpression>(
ast::BinaryOp::kAdd, CreatePullingPositionIdent(),
GenUint(element_stride * i));
expr_list.push_back(AccessPrimitive(buffer, cur_pos, base_format));
}
return mod->create<ast::TypeConstructorExpression>(
mod->create<ast::type::Vector>(base_type, count), std::move(expr_list));
}
ast::type::Type* VertexPulling::State::GetU32Type() {
return mod->create<ast::type::U32>();
}
ast::type::Type* VertexPulling::State::GetI32Type() {
return mod->create<ast::type::I32>();
}
ast::type::Type* VertexPulling::State::GetF32Type() {
return mod->create<ast::type::F32>();
}
VertexBufferLayoutDescriptor::VertexBufferLayoutDescriptor() = default;
VertexBufferLayoutDescriptor::VertexBufferLayoutDescriptor(
uint64_t in_array_stride,
InputStepMode in_step_mode,
std::vector<VertexAttributeDescriptor> in_attributes)
: array_stride(in_array_stride),
step_mode(in_step_mode),
attributes(std::move(in_attributes)) {}
VertexBufferLayoutDescriptor::VertexBufferLayoutDescriptor(
const VertexBufferLayoutDescriptor& other) = default;
VertexBufferLayoutDescriptor& VertexBufferLayoutDescriptor::operator=(
const VertexBufferLayoutDescriptor& other) = default;
VertexBufferLayoutDescriptor::~VertexBufferLayoutDescriptor() = default;
} // namespace transform
} // namespace tint

View File

@ -0,0 +1,272 @@
// Copyright 2020 The Tint Authors.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#ifndef SRC_TRANSFORM_VERTEX_PULLING_H_
#define SRC_TRANSFORM_VERTEX_PULLING_H_
#include <memory>
#include <string>
#include <unordered_map>
#include <vector>
#include "src/ast/expression.h"
#include "src/ast/function.h"
#include "src/ast/module.h"
#include "src/ast/statement.h"
#include "src/ast/variable.h"
#include "src/transform/transform.h"
namespace tint {
namespace transform {
/// Describes the format of data in a vertex buffer
enum class VertexFormat {
kVec2U8,
kVec4U8,
kVec2I8,
kVec4I8,
kVec2U8Norm,
kVec4U8Norm,
kVec2I8Norm,
kVec4I8Norm,
kVec2U16,
kVec4U16,
kVec2I16,
kVec4I16,
kVec2U16Norm,
kVec4U16Norm,
kVec2I16Norm,
kVec4I16Norm,
kVec2F16,
kVec4F16,
kF32,
kVec2F32,
kVec3F32,
kVec4F32,
kU32,
kVec2U32,
kVec3U32,
kVec4U32,
kI32,
kVec2I32,
kVec3I32,
kVec4I32
};
/// Describes if a vertex attributes increments with vertex index or instance
/// index
enum class InputStepMode { kVertex, kInstance };
/// Describes a vertex attribute within a buffer
struct VertexAttributeDescriptor {
/// The format of the attribute
VertexFormat format;
/// The byte offset of the attribute in the buffer
uint64_t offset;
/// The shader location used for the attribute
uint32_t shader_location;
};
/// Describes a buffer containing multiple vertex attributes
struct VertexBufferLayoutDescriptor {
/// Constructor
VertexBufferLayoutDescriptor();
/// Constructor
/// @param in_array_stride the array stride of the in buffer
/// @param in_step_mode the step mode of the in buffer
/// @param in_attributes the in attributes
VertexBufferLayoutDescriptor(
uint64_t in_array_stride,
InputStepMode in_step_mode,
std::vector<VertexAttributeDescriptor> in_attributes);
/// Copy constructor
/// @param other the struct to copy
VertexBufferLayoutDescriptor(const VertexBufferLayoutDescriptor& other);
/// Assignment operator
/// @param other the struct to copy
/// @returns this struct
VertexBufferLayoutDescriptor& operator=(
const VertexBufferLayoutDescriptor& other);
~VertexBufferLayoutDescriptor();
/// The array stride used in the in buffer
uint64_t array_stride = 0u;
/// The input step mode used
InputStepMode step_mode = InputStepMode::kVertex;
/// The vertex attributes
std::vector<VertexAttributeDescriptor> attributes;
};
/// Describes vertex state, which consists of many buffers containing vertex
/// attributes
using VertexStateDescriptor = std::vector<VertexBufferLayoutDescriptor>;
/// Converts a module to use vertex pulling
///
/// Variables which accept vertex input are var<in> with a location decoration.
/// This transform will convert those to be assigned from storage buffers
/// instead. The intention is to allow vertex input to rely on a storage buffer
/// clamping pass for out of bounds reads. We bind the storage buffers as arrays
/// of u32, so any read to byte position `p` will actually need to read position
/// `p / 4`, since `sizeof(u32) == 4`.
///
/// `VertexFormat` represents the input type of the attribute. This isn't
/// related to the type of the variable in the shader. For example,
/// `VertexFormat::kVec2F16` tells us that the buffer will contain `f16`
/// elements, to be read as vec2. In the shader, a user would make a `vec2<f32>`
/// to be able to use them. The conversion between `f16` and `f32` will need to
/// be handled by us (using unpack functions).
///
/// To be clear, there won't be types such as `f16` or `u8` anywhere in WGSL
/// code, but these are types that the data may arrive as. We need to convert
/// these smaller types into the base types such as `f32` and `u32` for the
/// shader to use.
class VertexPulling : public Transform {
public:
/// Constructor
VertexPulling();
/// Destructor
~VertexPulling() override;
/// Sets the vertex state descriptor, containing info about attributes
/// @param vertex_state the vertex state descriptor
void SetVertexState(const VertexStateDescriptor& vertex_state);
/// Sets the entry point to add assignments into
/// @param entry_point the vertex stage entry point
void SetEntryPoint(std::string entry_point);
/// Sets the "set" we will put all our vertex buffers into (as storage
/// buffers)
/// @param number the set number we will use
void SetPullingBufferBindingSet(uint32_t number);
/// Runs the transform on `module`, returning the transformation result.
/// @note Users of Tint should register the transform with transform manager
/// and invoke its Run(), instead of directly calling the transform's Run().
/// Calling Run() directly does not perform module state cleanup operations.
/// @param module the source module to transform
/// @returns the transformation result
Output Run(ast::Module* module) override;
private:
struct Config {
Config();
Config(const Config&);
~Config();
std::string entry_point_name;
VertexStateDescriptor vertex_state;
bool vertex_state_set = false;
// Default to 4 as it is past the limits of user-accessible sets
uint32_t pulling_set = 4u;
};
Config cfg;
struct State {
State(ast::Module* m, const Config& c);
~State();
/// Generate the vertex buffer binding name
/// @param index index to append to buffer name
std::string GetVertexBufferName(uint32_t index);
/// Inserts vertex_idx binding, or finds the existing one
void FindOrInsertVertexIndexIfUsed();
/// Inserts instance_idx binding, or finds the existing one
void FindOrInsertInstanceIndexIfUsed();
/// Converts var<in> with a location decoration to var<private>
void ConvertVertexInputVariablesToPrivate();
/// Adds storage buffer decorated variables for the vertex buffers
void AddVertexStorageBuffers();
/// Adds assignment to the variables from the buffers
void AddVertexPullingPreamble(ast::Function* vertex_func);
/// Generates an expression holding a constant uint
/// @param value uint value
ast::Expression* GenUint(uint32_t value);
/// Generates an expression to read the shader value `kPullingPosVarName`
ast::Expression* CreatePullingPositionIdent();
/// Generates an expression reading from a buffer a specific format.
/// This reads the value wherever `kPullingPosVarName` points to at the time
/// of the read.
/// @param buffer the index of the vertex buffer
/// @param format the format to read
ast::Expression* AccessByFormat(uint32_t buffer, VertexFormat format);
/// Generates an expression reading a uint32 from a vertex buffer
/// @param buffer the index of the vertex buffer
/// @param pos an expression for the position of the access, in bytes
ast::Expression* AccessU32(uint32_t buffer, ast::Expression* pos);
/// Generates an expression reading an int32 from a vertex buffer
/// @param buffer the index of the vertex buffer
/// @param pos an expression for the position of the access, in bytes
ast::Expression* AccessI32(uint32_t buffer, ast::Expression* pos);
/// Generates an expression reading a float from a vertex buffer
/// @param buffer the index of the vertex buffer
/// @param pos an expression for the position of the access, in bytes
ast::Expression* AccessF32(uint32_t buffer, ast::Expression* pos);
/// Generates an expression reading a basic type (u32, i32, f32) from a
/// vertex buffer
/// @param buffer the index of the vertex buffer
/// @param pos an expression for the position of the access, in bytes
/// @param format the underlying vertex format
ast::Expression* AccessPrimitive(uint32_t buffer,
ast::Expression* pos,
VertexFormat format);
/// Generates an expression reading a vec2/3/4 from a vertex buffer.
/// This reads the value wherever `kPullingPosVarName` points to at the time
/// of the read.
/// @param buffer the index of the vertex buffer
/// @param element_stride stride between elements, in bytes
/// @param base_type underlying AST type
/// @param base_format underlying vertex format
/// @param count how many elements the vector has
ast::Expression* AccessVec(uint32_t buffer,
uint32_t element_stride,
ast::type::Type* base_type,
VertexFormat base_format,
uint32_t count);
// Used to grab corresponding types from the type manager
ast::type::Type* GetU32Type();
ast::type::Type* GetI32Type();
ast::type::Type* GetF32Type();
ast::Module* const mod;
Config const cfg;
std::unordered_map<uint32_t, ast::Variable*> location_to_var;
std::string vertex_index_name;
std::string instance_index_name;
};
};
} // namespace transform
} // namespace tint
#endif // SRC_TRANSFORM_VERTEX_PULLING_H_

View File

@ -12,7 +12,7 @@
// See the License for the specific language governing permissions and // See the License for the specific language governing permissions and
// limitations under the License. // limitations under the License.
#include "src/transform/vertex_pulling_transform.h" #include "src/transform/vertex_pulling.h"
#include <utility> #include <utility>
@ -25,6 +25,7 @@
#include "src/ast/type/f32_type.h" #include "src/ast/type/f32_type.h"
#include "src/ast/type/i32_type.h" #include "src/ast/type/i32_type.h"
#include "src/ast/type/void_type.h" #include "src/ast/type/void_type.h"
#include "src/diagnostic/formatter.h"
#include "src/transform/manager.h" #include "src/transform/manager.h"
#include "src/type_determiner.h" #include "src/type_determiner.h"
#include "src/validator/validator.h" #include "src/validator/validator.h"
@ -33,13 +34,12 @@ namespace tint {
namespace transform { namespace transform {
namespace { namespace {
class VertexPullingTransformHelper { class VertexPullingHelper {
public: public:
VertexPullingTransformHelper() { VertexPullingHelper() {
mod_ = std::make_unique<ast::Module>(); mod_ = std::make_unique<ast::Module>();
manager_ = std::make_unique<Manager>(&ctx_, mod_.get()); manager_ = std::make_unique<Manager>();
auto transform = auto transform = std::make_unique<VertexPulling>();
std::make_unique<VertexPullingTransform>(&ctx_, mod_.get());
transform_ = transform.get(); transform_ = transform.get();
manager_->append(std::move(transform)); manager_->append(std::move(transform));
} }
@ -61,8 +61,7 @@ class VertexPullingTransformHelper {
TypeDeterminer td(&ctx_, mod_.get()); TypeDeterminer td(&ctx_, mod_.get());
EXPECT_TRUE(td.Determine()); EXPECT_TRUE(td.Determine());
transform_->SetVertexState( transform_->SetVertexState(vertex_state);
std::make_unique<VertexStateDescriptor>(vertex_state));
transform_->SetEntryPoint("main"); transform_->SetEntryPoint("main");
} }
@ -82,7 +81,7 @@ class VertexPullingTransformHelper {
ast::Module* mod() { return mod_.get(); } ast::Module* mod() { return mod_.get(); }
Manager* manager() { return manager_.get(); } Manager* manager() { return manager_.get(); }
VertexPullingTransform* transform() { return transform_; } VertexPulling* transform() { return transform_; }
/// Creates a new `ast::Node` owned by the Module. When the Module is /// Creates a new `ast::Node` owned by the Module. When the Module is
/// destructed, the `ast::Node` will also be destructed. /// destructed, the `ast::Node` will also be destructed.
@ -97,33 +96,38 @@ class VertexPullingTransformHelper {
Context ctx_; Context ctx_;
std::unique_ptr<ast::Module> mod_; std::unique_ptr<ast::Module> mod_;
std::unique_ptr<Manager> manager_; std::unique_ptr<Manager> manager_;
VertexPullingTransform* transform_; VertexPulling* transform_;
}; };
class VertexPullingTransformTest : public VertexPullingTransformHelper, class VertexPullingTest : public VertexPullingHelper, public testing::Test {};
public testing::Test {};
TEST_F(VertexPullingTransformTest, Error_NoVertexState) { TEST_F(VertexPullingTest, Error_NoVertexState) {
EXPECT_FALSE(manager()->Run()); auto result = manager()->Run(mod());
EXPECT_EQ(manager()->error(), "SetVertexState not called"); EXPECT_TRUE(result.diagnostics.contains_errors());
EXPECT_EQ(diag::Formatter().format(result.diagnostics),
"error: SetVertexState not called");
} }
TEST_F(VertexPullingTransformTest, Error_NoEntryPoint) { TEST_F(VertexPullingTest, Error_NoEntryPoint) {
transform()->SetVertexState(std::make_unique<VertexStateDescriptor>()); transform()->SetVertexState({});
EXPECT_FALSE(manager()->Run()); auto result = manager()->Run(mod());
EXPECT_EQ(manager()->error(), "Vertex stage entry point not found"); EXPECT_TRUE(result.diagnostics.contains_errors());
EXPECT_EQ(diag::Formatter().format(result.diagnostics),
"error: Vertex stage entry point not found");
} }
TEST_F(VertexPullingTransformTest, Error_InvalidEntryPoint) { TEST_F(VertexPullingTest, Error_InvalidEntryPoint) {
InitBasicModule(); InitBasicModule();
InitTransform({}); InitTransform({});
transform()->SetEntryPoint("_"); transform()->SetEntryPoint("_");
EXPECT_FALSE(manager()->Run()); auto result = manager()->Run(mod());
EXPECT_EQ(manager()->error(), "Vertex stage entry point not found"); EXPECT_TRUE(result.diagnostics.contains_errors());
EXPECT_EQ(diag::Formatter().format(result.diagnostics),
"error: Vertex stage entry point not found");
} }
TEST_F(VertexPullingTransformTest, Error_EntryPointWrongStage) { TEST_F(VertexPullingTest, Error_EntryPointWrongStage) {
auto* func = create<ast::Function>("main", ast::VariableList{}, auto* func = create<ast::Function>("main", ast::VariableList{},
mod()->create<ast::type::Void>(), mod()->create<ast::type::Void>(),
create<ast::BlockStatement>()); create<ast::BlockStatement>());
@ -132,17 +136,20 @@ TEST_F(VertexPullingTransformTest, Error_EntryPointWrongStage) {
mod()->AddFunction(func); mod()->AddFunction(func);
InitTransform({}); InitTransform({});
EXPECT_FALSE(manager()->Run()); auto result = manager()->Run(mod());
EXPECT_EQ(manager()->error(), "Vertex stage entry point not found"); EXPECT_TRUE(result.diagnostics.contains_errors());
EXPECT_EQ(diag::Formatter().format(result.diagnostics),
"error: Vertex stage entry point not found");
} }
TEST_F(VertexPullingTransformTest, BasicModule) { TEST_F(VertexPullingTest, BasicModule) {
InitBasicModule(); InitBasicModule();
InitTransform({}); InitTransform({});
EXPECT_TRUE(manager()->Run()); auto result = manager()->Run(mod());
ASSERT_FALSE(result.diagnostics.contains_errors());
} }
TEST_F(VertexPullingTransformTest, OneAttribute) { TEST_F(VertexPullingTest, OneAttribute) {
InitBasicModule(); InitBasicModule();
ast::type::F32 f32; ast::type::F32 f32;
@ -150,7 +157,8 @@ TEST_F(VertexPullingTransformTest, OneAttribute) {
InitTransform({{{4, InputStepMode::kVertex, {{VertexFormat::kF32, 0, 0}}}}}); InitTransform({{{4, InputStepMode::kVertex, {{VertexFormat::kF32, 0, 0}}}}});
EXPECT_TRUE(manager()->Run()); auto result = manager()->Run(mod());
ASSERT_FALSE(result.diagnostics.contains_errors());
EXPECT_EQ(R"(Module{ EXPECT_EQ(R"(Module{
TintVertexData Struct{ TintVertexData Struct{
@ -223,10 +231,10 @@ TEST_F(VertexPullingTransformTest, OneAttribute) {
} }
} }
)", )",
mod()->to_str()); result.module.to_str());
} }
TEST_F(VertexPullingTransformTest, OneInstancedAttribute) { TEST_F(VertexPullingTest, OneInstancedAttribute) {
InitBasicModule(); InitBasicModule();
ast::type::F32 f32; ast::type::F32 f32;
@ -235,7 +243,8 @@ TEST_F(VertexPullingTransformTest, OneInstancedAttribute) {
InitTransform( InitTransform(
{{{4, InputStepMode::kInstance, {{VertexFormat::kF32, 0, 0}}}}}); {{{4, InputStepMode::kInstance, {{VertexFormat::kF32, 0, 0}}}}});
EXPECT_TRUE(manager()->Run()); auto result = manager()->Run(mod());
ASSERT_FALSE(result.diagnostics.contains_errors());
EXPECT_EQ(R"(Module{ EXPECT_EQ(R"(Module{
TintVertexData Struct{ TintVertexData Struct{
@ -308,10 +317,10 @@ TEST_F(VertexPullingTransformTest, OneInstancedAttribute) {
} }
} }
)", )",
mod()->to_str()); result.module.to_str());
} }
TEST_F(VertexPullingTransformTest, OneAttributeDifferentOutputSet) { TEST_F(VertexPullingTest, OneAttributeDifferentOutputSet) {
InitBasicModule(); InitBasicModule();
ast::type::F32 f32; ast::type::F32 f32;
@ -320,7 +329,8 @@ TEST_F(VertexPullingTransformTest, OneAttributeDifferentOutputSet) {
InitTransform({{{4, InputStepMode::kVertex, {{VertexFormat::kF32, 0, 0}}}}}); InitTransform({{{4, InputStepMode::kVertex, {{VertexFormat::kF32, 0, 0}}}}});
transform()->SetPullingBufferBindingSet(5); transform()->SetPullingBufferBindingSet(5);
EXPECT_TRUE(manager()->Run()); auto result = manager()->Run(mod());
ASSERT_FALSE(result.diagnostics.contains_errors());
EXPECT_EQ(R"(Module{ EXPECT_EQ(R"(Module{
TintVertexData Struct{ TintVertexData Struct{
@ -393,11 +403,11 @@ TEST_F(VertexPullingTransformTest, OneAttributeDifferentOutputSet) {
} }
} }
)", )",
mod()->to_str()); result.module.to_str());
} }
// We expect the transform to use an existing builtin variables if it finds them // We expect the transform to use an existing builtin variables if it finds them
TEST_F(VertexPullingTransformTest, ExistingVertexIndexAndInstanceIndex) { TEST_F(VertexPullingTest, ExistingVertexIndexAndInstanceIndex) {
InitBasicModule(); InitBasicModule();
ast::type::F32 f32; ast::type::F32 f32;
@ -435,7 +445,8 @@ TEST_F(VertexPullingTransformTest, ExistingVertexIndexAndInstanceIndex) {
{{{4, InputStepMode::kVertex, {{VertexFormat::kF32, 0, 0}}}, {{{4, InputStepMode::kVertex, {{VertexFormat::kF32, 0, 0}}},
{4, InputStepMode::kInstance, {{VertexFormat::kF32, 0, 1}}}}}); {4, InputStepMode::kInstance, {{VertexFormat::kF32, 0, 1}}}}});
EXPECT_TRUE(manager()->Run()); auto result = manager()->Run(mod());
ASSERT_FALSE(result.diagnostics.contains_errors());
EXPECT_EQ(R"(Module{ EXPECT_EQ(R"(Module{
TintVertexData Struct{ TintVertexData Struct{
@ -558,10 +569,10 @@ TEST_F(VertexPullingTransformTest, ExistingVertexIndexAndInstanceIndex) {
} }
} }
)", )",
mod()->to_str()); result.module.to_str());
} }
TEST_F(VertexPullingTransformTest, TwoAttributesSameBuffer) { TEST_F(VertexPullingTest, TwoAttributesSameBuffer) {
InitBasicModule(); InitBasicModule();
ast::type::F32 f32; ast::type::F32 f32;
@ -575,7 +586,8 @@ TEST_F(VertexPullingTransformTest, TwoAttributesSameBuffer) {
InputStepMode::kVertex, InputStepMode::kVertex,
{{VertexFormat::kF32, 0, 0}, {VertexFormat::kVec4F32, 0, 1}}}}}); {{VertexFormat::kF32, 0, 0}, {VertexFormat::kVec4F32, 0, 1}}}}});
EXPECT_TRUE(manager()->Run()); auto result = manager()->Run(mod());
ASSERT_FALSE(result.diagnostics.contains_errors());
EXPECT_EQ(R"(Module{ EXPECT_EQ(R"(Module{
TintVertexData Struct{ TintVertexData Struct{
@ -739,10 +751,10 @@ TEST_F(VertexPullingTransformTest, TwoAttributesSameBuffer) {
} }
} }
)", )",
mod()->to_str()); result.module.to_str());
} }
TEST_F(VertexPullingTransformTest, FloatVectorAttributes) { TEST_F(VertexPullingTest, FloatVectorAttributes) {
InitBasicModule(); InitBasicModule();
ast::type::F32 f32; ast::type::F32 f32;
@ -760,7 +772,8 @@ TEST_F(VertexPullingTransformTest, FloatVectorAttributes) {
{12, InputStepMode::kVertex, {{VertexFormat::kVec3F32, 0, 1}}}, {12, InputStepMode::kVertex, {{VertexFormat::kVec3F32, 0, 1}}},
{16, InputStepMode::kVertex, {{VertexFormat::kVec4F32, 0, 2}}}}}); {16, InputStepMode::kVertex, {{VertexFormat::kVec4F32, 0, 2}}}}});
EXPECT_TRUE(manager()->Run()); auto result = manager()->Run(mod());
ASSERT_FALSE(result.diagnostics.contains_errors());
EXPECT_EQ(R"(Module{ EXPECT_EQ(R"(Module{
TintVertexData Struct{ TintVertexData Struct{
@ -1040,7 +1053,7 @@ TEST_F(VertexPullingTransformTest, FloatVectorAttributes) {
} }
} }
)", )",
mod()->to_str()); result.module.to_str());
} }
} // namespace } // namespace

View File

@ -1,463 +0,0 @@
// Copyright 2020 The Tint Authors.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include "src/transform/vertex_pulling_transform.h"
#include <utility>
#include "src/ast/array_accessor_expression.h"
#include "src/ast/assignment_statement.h"
#include "src/ast/binary_expression.h"
#include "src/ast/bitcast_expression.h"
#include "src/ast/decorated_variable.h"
#include "src/ast/member_accessor_expression.h"
#include "src/ast/scalar_constructor_expression.h"
#include "src/ast/stride_decoration.h"
#include "src/ast/struct.h"
#include "src/ast/struct_block_decoration.h"
#include "src/ast/struct_decoration.h"
#include "src/ast/struct_member.h"
#include "src/ast/struct_member_offset_decoration.h"
#include "src/ast/type/array_type.h"
#include "src/ast/type/f32_type.h"
#include "src/ast/type/i32_type.h"
#include "src/ast/type/struct_type.h"
#include "src/ast/type/u32_type.h"
#include "src/ast/type/vector_type.h"
#include "src/ast/type_constructor_expression.h"
#include "src/ast/uint_literal.h"
#include "src/ast/variable_decl_statement.h"
namespace tint {
namespace transform {
namespace {
static const char kVertexBufferNamePrefix[] = "_tint_pulling_vertex_buffer_";
static const char kStructBufferName[] = "_tint_vertex_data";
static const char kStructName[] = "TintVertexData";
static const char kPullingPosVarName[] = "_tint_pulling_pos";
static const char kDefaultVertexIndexName[] = "_tint_pulling_vertex_index";
static const char kDefaultInstanceIndexName[] = "_tint_pulling_instance_index";
} // namespace
VertexPullingTransform::VertexPullingTransform(ast::Module* mod)
: Transformer(mod) {}
VertexPullingTransform::VertexPullingTransform(Context*, ast::Module* mod)
: VertexPullingTransform(mod) {}
VertexPullingTransform::~VertexPullingTransform() = default;
void VertexPullingTransform::SetVertexState(
std::unique_ptr<VertexStateDescriptor> vertex_state) {
vertex_state_ = std::move(vertex_state);
}
void VertexPullingTransform::SetEntryPoint(std::string entry_point) {
entry_point_name_ = std::move(entry_point);
}
void VertexPullingTransform::SetPullingBufferBindingSet(uint32_t number) {
pulling_set_ = number;
}
bool VertexPullingTransform::Run() {
// Check SetVertexState was called
if (vertex_state_ == nullptr) {
error_ = "SetVertexState not called";
return false;
}
// Find entry point
auto* func = mod_->FindFunctionByNameAndStage(entry_point_name_,
ast::PipelineStage::kVertex);
if (func == nullptr) {
error_ = "Vertex stage entry point not found";
return false;
}
// Save the vertex function
auto* vertex_func = mod_->FindFunctionByName(func->name());
// TODO(idanr): Need to check shader locations in descriptor cover all
// attributes
// TODO(idanr): Make sure we covered all error cases, to guarantee the
// following stages will pass
FindOrInsertVertexIndexIfUsed();
FindOrInsertInstanceIndexIfUsed();
ConvertVertexInputVariablesToPrivate();
AddVertexStorageBuffers();
AddVertexPullingPreamble(vertex_func);
return true;
}
std::string VertexPullingTransform::GetVertexBufferName(uint32_t index) {
return kVertexBufferNamePrefix + std::to_string(index);
}
void VertexPullingTransform::FindOrInsertVertexIndexIfUsed() {
bool uses_vertex_step_mode = false;
for (const VertexBufferLayoutDescriptor& buffer_layout :
vertex_state_->vertex_buffers) {
if (buffer_layout.step_mode == InputStepMode::kVertex) {
uses_vertex_step_mode = true;
break;
}
}
if (!uses_vertex_step_mode) {
return;
}
// Look for an existing vertex index builtin
for (auto* v : mod_->global_variables()) {
if (!v->Is<ast::DecoratedVariable>() ||
v->storage_class() != ast::StorageClass::kInput) {
continue;
}
for (auto* d : v->As<ast::DecoratedVariable>()->decorations()) {
if (auto* builtin = d->As<ast::BuiltinDecoration>()) {
if (builtin->value() == ast::Builtin::kVertexIdx) {
vertex_index_name_ = v->name();
return;
}
}
}
}
// We didn't find a vertex index builtin, so create one
vertex_index_name_ = kDefaultVertexIndexName;
auto* var = create<ast::DecoratedVariable>(create<ast::Variable>(
vertex_index_name_, ast::StorageClass::kInput, GetI32Type()));
ast::VariableDecorationList decorations;
decorations.push_back(
create<ast::BuiltinDecoration>(ast::Builtin::kVertexIdx, Source{}));
var->set_decorations(std::move(decorations));
mod_->AddGlobalVariable(var);
}
void VertexPullingTransform::FindOrInsertInstanceIndexIfUsed() {
bool uses_instance_step_mode = false;
for (const VertexBufferLayoutDescriptor& buffer_layout :
vertex_state_->vertex_buffers) {
if (buffer_layout.step_mode == InputStepMode::kInstance) {
uses_instance_step_mode = true;
break;
}
}
if (!uses_instance_step_mode) {
return;
}
// Look for an existing instance index builtin
for (auto* v : mod_->global_variables()) {
if (!v->Is<ast::DecoratedVariable>() ||
v->storage_class() != ast::StorageClass::kInput) {
continue;
}
for (auto* d : v->As<ast::DecoratedVariable>()->decorations()) {
if (auto* builtin = d->As<ast::BuiltinDecoration>()) {
if (builtin->value() == ast::Builtin::kInstanceIdx) {
instance_index_name_ = v->name();
return;
}
}
}
}
// We didn't find an instance index builtin, so create one
instance_index_name_ = kDefaultInstanceIndexName;
auto* var = create<ast::DecoratedVariable>(create<ast::Variable>(
instance_index_name_, ast::StorageClass::kInput, GetI32Type()));
ast::VariableDecorationList decorations;
decorations.push_back(
create<ast::BuiltinDecoration>(ast::Builtin::kInstanceIdx, Source{}));
var->set_decorations(std::move(decorations));
mod_->AddGlobalVariable(var);
}
void VertexPullingTransform::ConvertVertexInputVariablesToPrivate() {
for (auto*& v : mod_->global_variables()) {
if (!v->Is<ast::DecoratedVariable>() ||
v->storage_class() != ast::StorageClass::kInput) {
continue;
}
for (auto* d : v->As<ast::DecoratedVariable>()->decorations()) {
if (auto* l = d->As<ast::LocationDecoration>()) {
uint32_t location = l->value();
// This is where the replacement happens. Expressions use identifier
// strings instead of pointers, so we don't need to update any other
// place in the AST.
v = create<ast::Variable>(v->name(), ast::StorageClass::kPrivate,
v->type());
location_to_var_[location] = v;
break;
}
}
}
}
void VertexPullingTransform::AddVertexStorageBuffers() {
// TODO(idanr): Make this readonly https://github.com/gpuweb/gpuweb/issues/935
// The array inside the struct definition
auto internal_array = std::make_unique<ast::type::Array>(GetU32Type());
ast::ArrayDecorationList ary_decos;
ary_decos.push_back(create<ast::StrideDecoration>(4u, Source{}));
internal_array->set_decorations(std::move(ary_decos));
auto* internal_array_type = mod_->unique_type(std::move(internal_array));
// Creating the struct type
ast::StructMemberList members;
ast::StructMemberDecorationList member_dec;
member_dec.push_back(create<ast::StructMemberOffsetDecoration>(0u, Source{}));
members.push_back(create<ast::StructMember>(
kStructBufferName, internal_array_type, std::move(member_dec)));
ast::StructDecorationList decos;
decos.push_back(create<ast::StructBlockDecoration>(Source{}));
auto* struct_type = mod_->create<ast::type::Struct>(
kStructName, create<ast::Struct>(std::move(decos), std::move(members)));
for (uint32_t i = 0; i < vertex_state_->vertex_buffers.size(); ++i) {
// The decorated variable with struct type
auto* var = create<ast::DecoratedVariable>(
create<ast::Variable>(GetVertexBufferName(i),
ast::StorageClass::kStorageBuffer, struct_type));
// Add decorations
ast::VariableDecorationList decorations;
decorations.push_back(create<ast::BindingDecoration>(i, Source{}));
decorations.push_back(create<ast::SetDecoration>(pulling_set_, Source{}));
var->set_decorations(std::move(decorations));
mod_->AddGlobalVariable(var);
}
mod_->AddConstructedType(struct_type);
}
void VertexPullingTransform::AddVertexPullingPreamble(
ast::Function* vertex_func) {
// Assign by looking at the vertex descriptor to find attributes with matching
// location.
// A block statement allowing us to use append instead of insert
auto* block = create<ast::BlockStatement>();
// Declare the |kPullingPosVarName| variable in the shader
auto* pos_declaration =
create<ast::VariableDeclStatement>(create<ast::Variable>(
kPullingPosVarName, ast::StorageClass::kFunction, GetI32Type()));
// |kPullingPosVarName| refers to the byte location of the current read. We
// declare a variable in the shader to avoid having to reuse Expression
// objects.
block->append(pos_declaration);
for (uint32_t i = 0; i < vertex_state_->vertex_buffers.size(); ++i) {
const VertexBufferLayoutDescriptor& buffer_layout =
vertex_state_->vertex_buffers[i];
for (const VertexAttributeDescriptor& attribute_desc :
buffer_layout.attributes) {
auto it = location_to_var_.find(attribute_desc.shader_location);
if (it == location_to_var_.end()) {
continue;
}
auto* v = it->second;
// Identifier to index by
auto* index_identifier = create<ast::IdentifierExpression>(
buffer_layout.step_mode == InputStepMode::kVertex
? vertex_index_name_
: instance_index_name_);
// An expression for the start of the read in the buffer in bytes
auto* pos_value = create<ast::BinaryExpression>(
ast::BinaryOp::kAdd,
create<ast::BinaryExpression>(
ast::BinaryOp::kMultiply, index_identifier,
GenUint(static_cast<uint32_t>(buffer_layout.array_stride))),
GenUint(static_cast<uint32_t>(attribute_desc.offset)));
// Update position of the read
auto* set_pos_expr = create<ast::AssignmentStatement>(
CreatePullingPositionIdent(), pos_value);
block->append(set_pos_expr);
block->append(create<ast::AssignmentStatement>(
create<ast::IdentifierExpression>(v->name()),
AccessByFormat(i, attribute_desc.format)));
}
}
vertex_func->body()->insert(0, block);
}
ast::Expression* VertexPullingTransform::GenUint(uint32_t value) {
return create<ast::ScalarConstructorExpression>(
create<ast::UintLiteral>(GetU32Type(), value));
}
ast::Expression* VertexPullingTransform::CreatePullingPositionIdent() {
return create<ast::IdentifierExpression>(kPullingPosVarName);
}
ast::Expression* VertexPullingTransform::AccessByFormat(uint32_t buffer,
VertexFormat format) {
// TODO(idanr): this doesn't account for the format of the attribute in the
// shader. ex: vec<u32> in shader, and attribute claims VertexFormat::Float4
// right now, we would try to assign a vec4<f32> to this attribute, but we
// really need to assign a vec4<u32> by casting.
// We could split this function to first do memory accesses and unpacking into
// int/uint/float1-4/etc, then convert that variable to a var<in> with the
// conversion defined in the WebGPU spec.
switch (format) {
case VertexFormat::kU32:
return AccessU32(buffer, CreatePullingPositionIdent());
case VertexFormat::kI32:
return AccessI32(buffer, CreatePullingPositionIdent());
case VertexFormat::kF32:
return AccessF32(buffer, CreatePullingPositionIdent());
case VertexFormat::kVec2F32:
return AccessVec(buffer, 4, GetF32Type(), VertexFormat::kF32, 2);
case VertexFormat::kVec3F32:
return AccessVec(buffer, 4, GetF32Type(), VertexFormat::kF32, 3);
case VertexFormat::kVec4F32:
return AccessVec(buffer, 4, GetF32Type(), VertexFormat::kF32, 4);
default:
return nullptr;
}
}
ast::Expression* VertexPullingTransform::AccessU32(uint32_t buffer,
ast::Expression* pos) {
// Here we divide by 4, since the buffer is uint32 not uint8. The input buffer
// has byte offsets for each attribute, and we will convert it to u32 indexes
// by dividing. Then, that element is going to be read, and if needed,
// unpacked into an appropriate variable. All reads should end up here as a
// base case.
return create<ast::ArrayAccessorExpression>(
create<ast::MemberAccessorExpression>(
create<ast::IdentifierExpression>(GetVertexBufferName(buffer)),
create<ast::IdentifierExpression>(kStructBufferName)),
create<ast::BinaryExpression>(ast::BinaryOp::kDivide, pos, GenUint(4)));
}
ast::Expression* VertexPullingTransform::AccessI32(uint32_t buffer,
ast::Expression* pos) {
// as<T> reinterprets bits
return create<ast::BitcastExpression>(GetI32Type(), AccessU32(buffer, pos));
}
ast::Expression* VertexPullingTransform::AccessF32(uint32_t buffer,
ast::Expression* pos) {
// as<T> reinterprets bits
return create<ast::BitcastExpression>(GetF32Type(), AccessU32(buffer, pos));
}
ast::Expression* VertexPullingTransform::AccessPrimitive(uint32_t buffer,
ast::Expression* pos,
VertexFormat format) {
// This function uses a position expression to read, rather than using the
// position variable. This allows us to read from offset positions relative to
// |kPullingPosVarName|. We can't call AccessByFormat because it reads only
// from the position variable.
switch (format) {
case VertexFormat::kU32:
return AccessU32(buffer, pos);
case VertexFormat::kI32:
return AccessI32(buffer, pos);
case VertexFormat::kF32:
return AccessF32(buffer, pos);
default:
return nullptr;
}
}
ast::Expression* VertexPullingTransform::AccessVec(uint32_t buffer,
uint32_t element_stride,
ast::type::Type* base_type,
VertexFormat base_format,
uint32_t count) {
ast::ExpressionList expr_list;
for (uint32_t i = 0; i < count; ++i) {
// Offset read position by element_stride for each component
auto* cur_pos = create<ast::BinaryExpression>(ast::BinaryOp::kAdd,
CreatePullingPositionIdent(),
GenUint(element_stride * i));
expr_list.push_back(AccessPrimitive(buffer, cur_pos, base_format));
}
return create<ast::TypeConstructorExpression>(
mod_->create<ast::type::Vector>(base_type, count), std::move(expr_list));
}
ast::type::Type* VertexPullingTransform::GetU32Type() {
return mod_->create<ast::type::U32>();
}
ast::type::Type* VertexPullingTransform::GetI32Type() {
return mod_->create<ast::type::I32>();
}
ast::type::Type* VertexPullingTransform::GetF32Type() {
return mod_->create<ast::type::F32>();
}
VertexBufferLayoutDescriptor::VertexBufferLayoutDescriptor() = default;
VertexBufferLayoutDescriptor::VertexBufferLayoutDescriptor(
uint64_t in_array_stride,
InputStepMode in_step_mode,
std::vector<VertexAttributeDescriptor> in_attributes)
: array_stride(in_array_stride),
step_mode(in_step_mode),
attributes(std::move(in_attributes)) {}
VertexBufferLayoutDescriptor::VertexBufferLayoutDescriptor(
const VertexBufferLayoutDescriptor& other)
: array_stride(other.array_stride),
step_mode(other.step_mode),
attributes(other.attributes) {}
VertexBufferLayoutDescriptor::~VertexBufferLayoutDescriptor() = default;
VertexStateDescriptor::VertexStateDescriptor() = default;
VertexStateDescriptor::VertexStateDescriptor(
std::vector<VertexBufferLayoutDescriptor> in_vertex_buffers)
: vertex_buffers(std::move(in_vertex_buffers)) {}
VertexStateDescriptor::VertexStateDescriptor(const VertexStateDescriptor& other)
: vertex_buffers(other.vertex_buffers) {}
VertexStateDescriptor::~VertexStateDescriptor() = default;
} // namespace transform
} // namespace tint

View File

@ -1,269 +0,0 @@
// Copyright 2020 The Tint Authors.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#ifndef SRC_TRANSFORM_VERTEX_PULLING_TRANSFORM_H_
#define SRC_TRANSFORM_VERTEX_PULLING_TRANSFORM_H_
#include <memory>
#include <string>
#include <unordered_map>
#include <vector>
#include "src/ast/expression.h"
#include "src/ast/function.h"
#include "src/ast/module.h"
#include "src/ast/statement.h"
#include "src/ast/variable.h"
#include "src/transform/transformer.h"
namespace tint {
namespace transform {
/// Describes the format of data in a vertex buffer
enum class VertexFormat {
kVec2U8,
kVec4U8,
kVec2I8,
kVec4I8,
kVec2U8Norm,
kVec4U8Norm,
kVec2I8Norm,
kVec4I8Norm,
kVec2U16,
kVec4U16,
kVec2I16,
kVec4I16,
kVec2U16Norm,
kVec4U16Norm,
kVec2I16Norm,
kVec4I16Norm,
kVec2F16,
kVec4F16,
kF32,
kVec2F32,
kVec3F32,
kVec4F32,
kU32,
kVec2U32,
kVec3U32,
kVec4U32,
kI32,
kVec2I32,
kVec3I32,
kVec4I32
};
/// Describes if a vertex attribtes increments with vertex index or instance
/// index
enum class InputStepMode { kVertex, kInstance };
/// Describes a vertex attribute within a buffer
struct VertexAttributeDescriptor {
/// The format of the attribute
VertexFormat format;
/// The byte offset of the attribute in the buffer
uint64_t offset;
/// The shader location used for the attribute
uint32_t shader_location;
};
/// Describes a buffer containing multiple vertex attributes
struct VertexBufferLayoutDescriptor {
/// Constructor
VertexBufferLayoutDescriptor();
/// Constructor
/// @param in_array_stride the array stride of the in buffer
/// @param in_step_mode the step mode of the in buffer
/// @param in_attributes the in attributes
VertexBufferLayoutDescriptor(
uint64_t in_array_stride,
InputStepMode in_step_mode,
std::vector<VertexAttributeDescriptor> in_attributes);
/// Copy constructor
/// @param other the struct to copy
VertexBufferLayoutDescriptor(const VertexBufferLayoutDescriptor& other);
~VertexBufferLayoutDescriptor();
/// The array stride used in the in buffer
uint64_t array_stride = 0u;
/// The input step mode used
InputStepMode step_mode = InputStepMode::kVertex;
/// The vertex attributes
std::vector<VertexAttributeDescriptor> attributes;
};
/// Describes vertex state, which consists of many buffers containing vertex
/// attributes
struct VertexStateDescriptor {
/// Constructor
VertexStateDescriptor();
/// Constructor
/// @param in_vertex_buffers the vertex buffers
VertexStateDescriptor(
std::vector<VertexBufferLayoutDescriptor> in_vertex_buffers);
/// Copy constructor
/// @param other the struct to copy
VertexStateDescriptor(const VertexStateDescriptor& other);
~VertexStateDescriptor();
/// The vertex buffers
std::vector<VertexBufferLayoutDescriptor> vertex_buffers;
};
/// Converts a module to use vertex pulling
///
/// Variables which accept vertex input are var<in> with a location decoration.
/// This transform will convert those to be assigned from storage buffers
/// instead. The intention is to allow vertex input to rely on a storage buffer
/// clamping pass for out of bounds reads. We bind the storage buffers as arrays
/// of u32, so any read to byte position `p` will actually need to read position
/// `p / 4`, since `sizeof(u32) == 4`.
///
/// `VertexFormat` represents the input type of the attribute. This isn't
/// related to the type of the variable in the shader. For example,
/// `VertexFormat::kVec2F16` tells us that the buffer will contain `f16`
/// elements, to be read as vec2. In the shader, a user would make a `vec2<f32>`
/// to be able to use them. The conversion between `f16` and `f32` will need to
/// be handled by us (using unpack functions).
///
/// To be clear, there won't be types such as `f16` or `u8` anywhere in WGSL
/// code, but these are types that the data may arrive as. We need to convert
/// these smaller types into the base types such as `f32` and `u32` for the
/// shader to use.
class VertexPullingTransform : public Transformer {
public:
/// Constructor
/// @param mod the module to convert to vertex pulling
explicit VertexPullingTransform(ast::Module* mod);
/// Constructor
/// DEPRECATED
/// @param ctx the tint context
/// @param mod the module to convert to vertex pulling
VertexPullingTransform(Context* ctx, ast::Module* mod);
~VertexPullingTransform() override;
/// Sets the vertex state descriptor, containing info about attributes
/// @param vertex_state the vertex state descriptor
void SetVertexState(std::unique_ptr<VertexStateDescriptor> vertex_state);
/// Sets the entry point to add assignments into
/// @param entry_point the vertex stage entry point
void SetEntryPoint(std::string entry_point);
/// Sets the "set" we will put all our vertex buffers into (as storage
/// buffers)
/// @param number the set number we will use
void SetPullingBufferBindingSet(uint32_t number);
/// Users of Tint should register the transform with transform manager and
/// invoke its Run(), instead of directly calling the transform's Run().
/// Calling Run() directly does not perform module state cleanup operations.
/// @returns true if the transformation was successful
bool Run() override;
private:
/// Generate the vertex buffer binding name
/// @param index index to append to buffer name
std::string GetVertexBufferName(uint32_t index);
/// Inserts vertex_idx binding, or finds the existing one
void FindOrInsertVertexIndexIfUsed();
/// Inserts instance_idx binding, or finds the existing one
void FindOrInsertInstanceIndexIfUsed();
/// Converts var<in> with a location decoration to var<private>
void ConvertVertexInputVariablesToPrivate();
/// Adds storage buffer decorated variables for the vertex buffers
void AddVertexStorageBuffers();
/// Adds assignment to the variables from the buffers
void AddVertexPullingPreamble(ast::Function* vertex_func);
/// Generates an expression holding a constant uint
/// @param value uint value
ast::Expression* GenUint(uint32_t value);
/// Generates an expression to read the shader value `kPullingPosVarName`
ast::Expression* CreatePullingPositionIdent();
/// Generates an expression reading from a buffer a specific format.
/// This reads the value wherever `kPullingPosVarName` points to at the time
/// of the read.
/// @param buffer the index of the vertex buffer
/// @param format the format to read
ast::Expression* AccessByFormat(uint32_t buffer, VertexFormat format);
/// Generates an expression reading a uint32 from a vertex buffer
/// @param buffer the index of the vertex buffer
/// @param pos an expression for the position of the access, in bytes
ast::Expression* AccessU32(uint32_t buffer, ast::Expression* pos);
/// Generates an expression reading an int32 from a vertex buffer
/// @param buffer the index of the vertex buffer
/// @param pos an expression for the position of the access, in bytes
ast::Expression* AccessI32(uint32_t buffer, ast::Expression* pos);
/// Generates an expression reading a float from a vertex buffer
/// @param buffer the index of the vertex buffer
/// @param pos an expression for the position of the access, in bytes
ast::Expression* AccessF32(uint32_t buffer, ast::Expression* pos);
/// Generates an expression reading a basic type (u32, i32, f32) from a vertex
/// buffer
/// @param buffer the index of the vertex buffer
/// @param pos an expression for the position of the access, in bytes
/// @param format the underlying vertex format
ast::Expression* AccessPrimitive(uint32_t buffer,
ast::Expression* pos,
VertexFormat format);
/// Generates an expression reading a vec2/3/4 from a vertex buffer.
/// This reads the value wherever `kPullingPosVarName` points to at the time
/// of the read.
/// @param buffer the index of the vertex buffer
/// @param element_stride stride between elements, in bytes
/// @param base_type underlying AST type
/// @param base_format underlying vertex format
/// @param count how many elements the vector has
ast::Expression* AccessVec(uint32_t buffer,
uint32_t element_stride,
ast::type::Type* base_type,
VertexFormat base_format,
uint32_t count);
// Used to grab corresponding types from the type manager
ast::type::Type* GetU32Type();
ast::type::Type* GetI32Type();
ast::type::Type* GetF32Type();
std::string entry_point_name_;
std::string vertex_index_name_;
std::string instance_index_name_;
// Default to 4 as it is past the limits of user-accessible sets
uint32_t pulling_set_ = 4u;
std::unordered_map<uint32_t, ast::Variable*> location_to_var_;
std::unique_ptr<VertexStateDescriptor> vertex_state_;
};
} // namespace transform
} // namespace tint
#endif // SRC_TRANSFORM_VERTEX_PULLING_TRANSFORM_H_