writer/msl: Remove legacy shader IO support

Mostly just deleting unneeded code, and a few additional cleanups as a
result.

Bug: tint:697
Change-Id: I31ceea93feb34994f51a1b6d294a35cf0c127447
Reviewed-on: https://dawn-review.googlesource.com/c/tint/+/55282
Kokoro: Kokoro <noreply+kokoro@google.com>
Commit-Queue: James Price <jrprice@google.com>
Reviewed-by: Ben Clayton <bclayton@google.com>
This commit is contained in:
James Price 2021-06-19 02:37:25 +00:00 committed by Tint LUCI CQ
parent a865b375aa
commit 5eadb83a3f
7 changed files with 22 additions and 987 deletions

View File

@ -896,7 +896,6 @@ if(${TINT_BUILD_TESTS})
writer/msl/generator_impl_constructor_test.cc
writer/msl/generator_impl_continue_test.cc
writer/msl/generator_impl_discard_test.cc
writer/msl/generator_impl_function_entry_point_data_test.cc
writer/msl/generator_impl_function_test.cc
writer/msl/generator_impl_identifier_test.cc
writer/msl/generator_impl_if_test.cc

View File

@ -57,11 +57,6 @@ namespace writer {
namespace msl {
namespace {
const char kInStructNameSuffix[] = "in";
const char kOutStructNameSuffix[] = "out";
const char kTintStructInVarPrefix[] = "_tint_in";
const char kTintStructOutVarPrefix[] = "_tint_out";
bool last_is_break_or_fallthrough(const ast::BlockStatement* stmts) {
if (stmts->empty()) {
return false;
@ -82,11 +77,6 @@ bool GeneratorImpl::Generate() {
out_ << "#include <metal_stdlib>" << std::endl << std::endl;
out_ << "using namespace metal;" << std::endl;
for (auto* global : program_->AST().GlobalVariables()) {
auto* sem = program_->Sem().Get(global);
global_variables_.set(global->symbol(), sem);
}
for (auto* const type_decl : program_->AST().TypeDecls()) {
if (!type_decl->Is<ast::Alias>()) {
if (!EmitTypeDecl(TypeOf(type_decl))) {
@ -120,31 +110,16 @@ bool GeneratorImpl::Generate() {
}
}
// Make sure all entry point data is emitted before the entry point functions
for (auto* func : program_->AST().Functions()) {
if (!func->IsEntryPoint()) {
continue;
if (!EmitFunction(func)) {
return false;
}
} else {
if (!EmitEntryPointFunction(func)) {
return false;
}
}
if (!EmitEntryPointData(func)) {
return false;
}
}
for (auto* func : program_->AST().Functions()) {
if (!EmitFunction(func)) {
return false;
}
}
for (auto* func : program_->AST().Functions()) {
if (!func->IsEntryPoint()) {
continue;
}
if (!EmitEntryPointFunction(func)) {
return false;
}
out_ << std::endl;
}
return true;
@ -312,27 +287,6 @@ bool GeneratorImpl::EmitBreak(ast::BreakStatement*) {
return true;
}
std::string GeneratorImpl::current_ep_var_name(VarType type) {
std::string name = "";
switch (type) {
case VarType::kIn: {
auto in_it = ep_sym_to_in_data_.find(current_ep_sym_);
if (in_it != ep_sym_to_in_data_.end()) {
name = in_it->second.var_name;
}
break;
}
case VarType::kOut: {
auto out_it = ep_sym_to_out_data_.find(current_ep_sym_);
if (out_it != ep_sym_to_out_data_.end()) {
name = out_it->second.var_name;
}
break;
}
}
return name;
}
bool GeneratorImpl::EmitCall(ast::CallExpression* expr) {
auto* ident = expr->func();
auto* call = program_->Sem().Get(expr);
@ -340,14 +294,6 @@ bool GeneratorImpl::EmitCall(ast::CallExpression* expr) {
return EmitIntrinsicCall(expr, intrinsic);
}
auto name = program_->Symbols().NameFor(ident->symbol());
auto caller_sym = ident->symbol();
auto it = ep_func_name_remapped_.find(current_ep_sym_.to_str() + "_" +
caller_sym.to_str());
if (it != ep_func_name_remapped_.end()) {
name = it->second;
}
auto* func = program_->AST().Functions().Find(ident->symbol());
if (func == nullptr) {
diagnostics_.add_error("Unable to find function: " +
@ -355,40 +301,10 @@ bool GeneratorImpl::EmitCall(ast::CallExpression* expr) {
return false;
}
out_ << name << "(";
out_ << program_->Symbols().NameFor(ident->symbol()) << "(";
bool first = true;
if (has_referenced_in_var_needing_struct(func)) {
auto var_name = current_ep_var_name(VarType::kIn);
if (!var_name.empty()) {
out_ << var_name;
first = false;
}
}
if (has_referenced_out_var_needing_struct(func)) {
auto var_name = current_ep_var_name(VarType::kOut);
if (!var_name.empty()) {
if (!first) {
out_ << ", ";
}
first = false;
out_ << var_name;
}
}
auto* func_sem = program_->Sem().Get(func);
for (const auto& data : func_sem->ReferencedBuiltinVariables()) {
auto* var = data.first;
if (var->StorageClass() != ast::StorageClass::kInput) {
continue;
}
if (!first) {
out_ << ", ";
}
first = false;
out_ << program_->Symbols().NameFor(var->Declaration()->symbol());
}
for (const auto& data : func_sem->ReferencedUniformVariables()) {
auto* var = data.first;
if (!first) {
@ -1061,127 +977,6 @@ bool GeneratorImpl::EmitLiteral(ast::Literal* lit) {
return true;
}
// TODO(crbug.com/tint/697): Remove this when we remove support for entry point
// params as module-scope globals.
bool GeneratorImpl::EmitEntryPointData(ast::Function* func) {
auto* func_sem = program_->Sem().Get(func);
std::vector<std::pair<const ast::Variable*, uint32_t>> in_locations;
std::vector<std::pair<const ast::Variable*, ast::Decoration*>> out_variables;
for (auto data : func_sem->ReferencedLocationVariables()) {
auto* var = data.first;
auto* deco = data.second;
if (var->StorageClass() == ast::StorageClass::kInput) {
in_locations.push_back({var->Declaration(), deco->value()});
} else if (var->StorageClass() == ast::StorageClass::kOutput) {
out_variables.push_back({var->Declaration(), deco});
}
}
for (auto data : func_sem->ReferencedBuiltinVariables()) {
auto* var = data.first;
auto* deco = data.second;
if (var->StorageClass() == ast::StorageClass::kOutput) {
out_variables.push_back({var->Declaration(), deco});
}
}
if (!in_locations.empty()) {
auto in_struct_name =
program_->Symbols().NameFor(func->symbol()) + "_" + kInStructNameSuffix;
auto* in_var_name = kTintStructInVarPrefix;
ep_sym_to_in_data_[func->symbol()] = {in_struct_name, in_var_name};
make_indent();
out_ << "struct " << in_struct_name << " {" << std::endl;
increment_indent();
for (auto& data : in_locations) {
auto* var = data.first;
uint32_t loc = data.second;
make_indent();
if (!EmitType(program_->Sem().Get(var)->Type()->UnwrapRef(),
program_->Symbols().NameFor(var->symbol()))) {
return false;
}
out_ << " " << program_->Symbols().NameFor(var->symbol()) << " [[";
if (func->pipeline_stage() == ast::PipelineStage::kVertex) {
out_ << "attribute(" << loc << ")";
} else if (func->pipeline_stage() == ast::PipelineStage::kFragment) {
out_ << "user(locn" << loc << ")";
} else {
diagnostics_.add_error("invalid location variable for pipeline stage");
return false;
}
out_ << "]];" << std::endl;
}
decrement_indent();
make_indent();
out_ << "};" << std::endl << std::endl;
}
if (!out_variables.empty()) {
auto out_struct_name = program_->Symbols().NameFor(func->symbol()) + "_" +
kOutStructNameSuffix;
auto* out_var_name = kTintStructOutVarPrefix;
ep_sym_to_out_data_[func->symbol()] = {out_struct_name, out_var_name};
make_indent();
out_ << "struct " << out_struct_name << " {" << std::endl;
increment_indent();
for (auto& data : out_variables) {
auto* var = data.first;
auto* deco = data.second;
make_indent();
if (!EmitType(program_->Sem().Get(var)->Type()->UnwrapRef(),
program_->Symbols().NameFor(var->symbol()))) {
return false;
}
out_ << " " << program_->Symbols().NameFor(var->symbol()) << " [[";
if (auto* location = deco->As<ast::LocationDecoration>()) {
auto loc = location->value();
if (func->pipeline_stage() == ast::PipelineStage::kVertex) {
out_ << "user(locn" << loc << ")";
} else if (func->pipeline_stage() == ast::PipelineStage::kFragment) {
out_ << "color(" << loc << ")";
} else {
diagnostics_.add_error(
"invalid location variable for pipeline stage");
return false;
}
} else if (auto* builtin = deco->As<ast::BuiltinDecoration>()) {
auto attr = builtin_to_attribute(builtin->value());
if (attr.empty()) {
diagnostics_.add_error("unsupported builtin");
return false;
}
out_ << attr;
} else {
diagnostics_.add_error(
"unsupported variable decoration for entry point output");
return false;
}
out_ << "]];" << std::endl;
}
decrement_indent();
make_indent();
out_ << "};" << std::endl << std::endl;
}
return true;
}
bool GeneratorImpl::EmitExpression(ast::Expression* expr) {
if (auto* a = expr->As<ast::ArrayAccessorExpression>()) {
return EmitArrayAccessor(a);
@ -1229,139 +1024,17 @@ void GeneratorImpl::EmitStage(ast::PipelineStage stage) {
return;
}
bool GeneratorImpl::has_referenced_in_var_needing_struct(ast::Function* func) {
auto* func_sem = program_->Sem().Get(func);
for (auto data : func_sem->ReferencedLocationVariables()) {
auto* var = data.first;
if (var->StorageClass() == ast::StorageClass::kInput) {
return true;
}
}
return false;
}
bool GeneratorImpl::has_referenced_out_var_needing_struct(ast::Function* func) {
auto* func_sem = program_->Sem().Get(func);
for (auto data : func_sem->ReferencedLocationVariables()) {
auto* var = data.first;
if (var->StorageClass() == ast::StorageClass::kOutput) {
return true;
}
}
for (auto data : func_sem->ReferencedBuiltinVariables()) {
auto* var = data.first;
if (var->StorageClass() == ast::StorageClass::kOutput) {
return true;
}
}
return false;
}
bool GeneratorImpl::has_referenced_var_needing_struct(ast::Function* func) {
return has_referenced_in_var_needing_struct(func) ||
has_referenced_out_var_needing_struct(func);
}
bool GeneratorImpl::EmitFunction(ast::Function* func) {
auto* func_sem = program_->Sem().Get(func);
make_indent();
// Entry points will be emitted later, skip for now.
if (func->IsEntryPoint()) {
return true;
}
// TODO(dsinclair): This could be smarter. If the input/outputs for multiple
// entry points are the same we could generate a single struct and then have
// this determine it's the same struct and just emit once.
bool emit_duplicate_functions = func_sem->AncestorEntryPoints().size() > 0 &&
has_referenced_var_needing_struct(func);
if (emit_duplicate_functions) {
for (const auto& ep_sym : func_sem->AncestorEntryPoints()) {
if (!EmitFunctionInternal(func, emit_duplicate_functions, ep_sym)) {
return false;
}
out_ << std::endl;
}
} else {
// Emit as non-duplicated
if (!EmitFunctionInternal(func, false, Symbol())) {
return false;
}
out_ << std::endl;
}
return true;
}
bool GeneratorImpl::EmitFunctionInternal(ast::Function* func,
bool emit_duplicate_functions,
Symbol ep_sym) {
auto* func_sem = program_->Sem().Get(func);
auto name = func->symbol().to_str();
if (!EmitType(func_sem->ReturnType(), "")) {
return false;
}
out_ << " ";
if (emit_duplicate_functions) {
auto func_name = name;
auto ep_name = ep_sym.to_str();
name = program_->Symbols().NameFor(func->symbol()) + "_" +
program_->Symbols().NameFor(ep_sym);
ep_func_name_remapped_[ep_name + "_" + func_name] = name;
} else {
name = program_->Symbols().NameFor(func->symbol());
}
out_ << name << "(";
out_ << " " << program_->Symbols().NameFor(func->symbol()) << "(";
bool first = true;
// If we're emitting duplicate functions that means the function takes
// the stage_in or stage_out value from the entry point, emit them.
//
// We emit both of them if they're there regardless of if they're both used.
if (emit_duplicate_functions) {
auto in_it = ep_sym_to_in_data_.find(ep_sym);
if (in_it != ep_sym_to_in_data_.end()) {
out_ << "thread " << in_it->second.struct_name << "& "
<< in_it->second.var_name;
first = false;
}
auto out_it = ep_sym_to_out_data_.find(ep_sym);
if (out_it != ep_sym_to_out_data_.end()) {
if (!first) {
out_ << ", ";
}
out_ << "thread " << out_it->second.struct_name << "& "
<< out_it->second.var_name;
first = false;
}
}
for (const auto& data : func_sem->ReferencedBuiltinVariables()) {
auto* var = data.first;
if (var->StorageClass() != ast::StorageClass::kInput) {
continue;
}
if (!first) {
out_ << ", ";
}
first = false;
out_ << "thread ";
if (!EmitType(var->Type()->UnwrapRef(), "")) {
return false;
}
out_ << "& " << program_->Symbols().NameFor(var->Declaration()->symbol());
}
for (const auto& data : func_sem->ReferencedUniformVariables()) {
auto* var = data.first;
if (!first) {
@ -1416,13 +1089,11 @@ bool GeneratorImpl::EmitFunctionInternal(ast::Function* func,
out_ << ") ";
current_ep_sym_ = ep_sym;
if (!EmitBlockAndNewline(func->body())) {
return false;
}
current_ep_sym_ = Symbol();
out_ << std::endl;
return true;
}
@ -1462,37 +1133,12 @@ bool GeneratorImpl::EmitEntryPointFunction(ast::Function* func) {
make_indent();
current_ep_sym_ = func->symbol();
EmitStage(func->pipeline_stage());
out_ << " ";
// This is an entry point, the return type is the entry point output structure
// if one exists, or void otherwise.
auto out_data = ep_sym_to_out_data_.find(current_ep_sym_);
bool has_out_data = out_data != ep_sym_to_out_data_.end();
if (has_out_data) {
// TODO(crbug.com/tint/697): Remove this.
if (!func->return_type()->Is<ast::Void>()) {
TINT_ICE(diagnostics_) << "Mixing module-scope variables and return "
"types for shader outputs";
}
out_ << out_data->second.struct_name;
} else {
out_ << func->return_type()->FriendlyName(program_->Symbols());
}
out_ << " " << func->return_type()->FriendlyName(program_->Symbols());
out_ << " " << program_->Symbols().NameFor(func->symbol()) << "(";
bool first = true;
// TODO(crbug.com/tint/697): Remove this.
auto in_data = ep_sym_to_in_data_.find(current_ep_sym_);
if (in_data != ep_sym_to_in_data_.end()) {
out_ << in_data->second.struct_name << " " << in_data->second.var_name
<< " [[stage_in]]";
first = false;
}
// Emit entry point parameters.
bool first = true;
for (auto* var : func->params()) {
if (!first) {
out_ << ", ";
@ -1549,33 +1195,6 @@ bool GeneratorImpl::EmitEntryPointFunction(ast::Function* func) {
}
}
// TODO(crbug.com/tint/697): Remove this.
for (auto data : func_sem->ReferencedBuiltinVariables()) {
auto* var = data.first;
if (var->StorageClass() != ast::StorageClass::kInput) {
continue;
}
if (!first) {
out_ << ", ";
}
first = false;
auto* builtin = data.second;
if (!EmitType(var->Type()->UnwrapRef(), "")) {
return false;
}
auto attr = builtin_to_attribute(builtin->value());
if (attr.empty()) {
diagnostics_.add_error("unknown builtin");
return false;
}
out_ << " " << program_->Symbols().NameFor(var->Declaration()->symbol())
<< " [[" << attr << "]]";
}
for (auto data : func_sem->ReferencedUniformVariables()) {
if (!first) {
out_ << ", ";
@ -1634,13 +1253,6 @@ bool GeneratorImpl::EmitEntryPointFunction(ast::Function* func) {
increment_indent();
if (has_out_data) {
make_indent();
out_ << out_data->second.struct_name << " " << out_data->second.var_name
<< " = {};" << std::endl;
}
generating_entry_point_ = true;
for (auto* s : *func->body()) {
if (!EmitStatement(s)) {
return false;
@ -1654,49 +1266,15 @@ bool GeneratorImpl::EmitEntryPointFunction(ast::Function* func) {
return false;
}
}
generating_entry_point_ = false;
decrement_indent();
make_indent();
out_ << "}" << std::endl;
current_ep_sym_ = Symbol();
out_ << "}" << std::endl << std::endl;
return true;
}
bool GeneratorImpl::global_is_in_struct(const sem::Variable* var) const {
auto& decorations = var->Declaration()->decorations();
bool in_or_out_struct_has_location =
var != nullptr &&
ast::HasDecoration<ast::LocationDecoration>(decorations) &&
(var->StorageClass() == ast::StorageClass::kInput ||
var->StorageClass() == ast::StorageClass::kOutput);
bool in_struct_has_builtin =
var != nullptr &&
ast::HasDecoration<ast::BuiltinDecoration>(decorations) &&
var->StorageClass() == ast::StorageClass::kOutput;
return in_or_out_struct_has_location || in_struct_has_builtin;
}
bool GeneratorImpl::EmitIdentifier(ast::IdentifierExpression* expr) {
auto* ident = expr->As<ast::IdentifierExpression>();
const sem::Variable* var = nullptr;
if (global_variables_.get(ident->symbol(), &var)) {
if (global_is_in_struct(var)) {
auto var_type = var->StorageClass() == ast::StorageClass::kInput
? VarType::kIn
: VarType::kOut;
auto name = current_ep_var_name(var_type);
if (name.empty()) {
diagnostics_.add_error("unable to find entry point data for variable");
return false;
}
out_ << name << ".";
}
}
out_ << program_->Symbols().NameFor(ident->symbol());
out_ << program_->Symbols().NameFor(expr->symbol());
return true;
}
@ -1865,14 +1443,6 @@ bool GeneratorImpl::EmitReturn(ast::ReturnStatement* stmt) {
make_indent();
out_ << "return";
// TODO(crbug.com/tint/697): Remove this conditional.
if (generating_entry_point_) {
auto out_data = ep_sym_to_out_data_.find(current_ep_sym_);
if (out_data != ep_sym_to_out_data_.end()) {
out_ << " " << out_data->second.var_name;
}
}
if (stmt->has_value()) {
out_ << " ";
if (!EmitExpression(stmt->value())) {
@ -2363,8 +1933,7 @@ bool GeneratorImpl::EmitVariable(const sem::Variable* var,
}
} else if (var->StorageClass() == ast::StorageClass::kPrivate ||
var->StorageClass() == ast::StorageClass::kFunction ||
var->StorageClass() == ast::StorageClass::kNone ||
var->StorageClass() == ast::StorageClass::kOutput) {
var->StorageClass() == ast::StorageClass::kNone) {
out_ << " = ";
if (!EmitZeroValue(type)) {
return false;

View File

@ -16,7 +16,6 @@
#define SRC_WRITER_MSL_GENERATOR_IMPL_H_
#include <string>
#include <unordered_map>
#include "src/ast/array_accessor_expression.h"
#include "src/ast/assignment_statement.h"
@ -133,10 +132,6 @@ class GeneratorImpl : public TextGenerator {
/// @param stmt the statement to emit
/// @returns true if the statement was emitted
bool EmitElse(ast::ElseStatement* stmt);
/// Handles emitting information for an entry point
/// @param func the entry point function
/// @returns true if the entry point data was emitted
bool EmitEntryPointData(ast::Function* func);
/// Handles emitting the entry point function
/// @param func the entry point function
/// @returns true if the entry point function was emitted
@ -149,15 +144,6 @@ class GeneratorImpl : public TextGenerator {
/// @param func the function to generate
/// @returns true if the function was emitted
bool EmitFunction(ast::Function* func);
/// Internal helper for emitting functions
/// @param func the function to emit
/// @param emit_duplicate_functions set true if we need to duplicate per entry
/// point
/// @param ep_sym the current entry point or symbol::kInvalid if not set
/// @returns true if the function was emitted.
bool EmitFunctionInternal(ast::Function* func,
bool emit_duplicate_functions,
Symbol ep_sym);
/// Handles generating an identifier expression
/// @param expr the identifier expression
/// @returns true if the identifier was emitted
@ -235,44 +221,17 @@ class GeneratorImpl : public TextGenerator {
/// @returns true if the zero value was successfully emitted.
bool EmitZeroValue(const sem::Type* type);
/// Determines if the function needs the input struct passed to it.
/// @param func the function to check
/// @returns true if there are input struct variables used in the function
bool has_referenced_in_var_needing_struct(ast::Function* func);
/// Determines if the function needs the output struct passed to it.
/// @param func the function to check
/// @returns true if there are output struct variables used in the function
bool has_referenced_out_var_needing_struct(ast::Function* func);
/// Determines if any used module variable requires an input or output struct.
/// @param func the function to check
/// @returns true if an input or output struct is required.
bool has_referenced_var_needing_struct(ast::Function* func);
/// Handles generating a builtin name
/// @param intrinsic the semantic info for the intrinsic
/// @returns the name or "" if not valid
std::string generate_builtin_name(const sem::Intrinsic* intrinsic);
/// Checks if the global variable is in an input or output struct
/// @param var the variable to check
/// @returns true if the global is in an input or output struct
bool global_is_in_struct(const sem::Variable* var) const;
/// Converts a builtin to an attribute name
/// @param builtin the builtin to convert
/// @returns the string name of the builtin or blank on error
std::string builtin_to_attribute(ast::Builtin builtin) const;
private:
enum class VarType { kIn, kOut };
struct EntryPointData {
std::string struct_name;
std::string var_name;
};
std::string current_ep_var_name(VarType type);
/// @returns the resolved type of the ast::Expression `expr`
/// @param expr the expression
sem::Type* TypeOf(ast::Expression* expr) const {
@ -301,19 +260,8 @@ class GeneratorImpl : public TextGenerator {
/// type.
SizeAndAlign MslPackedTypeSizeAndAlign(const sem::Type* ty);
ScopeStack<const sem::Variable*> global_variables_;
Symbol current_ep_sym_;
bool generating_entry_point_ = false;
const Program* program_ = nullptr;
uint32_t loop_emission_counter_ = 0;
std::unordered_map<Symbol, EntryPointData> ep_sym_to_in_data_;
std::unordered_map<Symbol, EntryPointData> ep_sym_to_out_data_;
// This maps an input of "<entry_point_name>_<function_name>" to a remapped
// function name. If there is no entry for a given key then function did
// not need to be remapped for the entry point and can be emitted directly.
std::unordered_map<std::string, std::string> ep_func_name_remapped_;
};
} // namespace msl

View File

@ -42,8 +42,8 @@ TEST_F(MslGeneratorImplTest, EmitExpression_Call_WithParams) {
Param(Sym(), ty.f32()),
},
ty.void_(), ast::StatementList{}, ast::DecorationList{});
Global("param1", ty.f32(), ast::StorageClass::kInput);
Global("param2", ty.f32(), ast::StorageClass::kInput);
Global("param1", ty.f32(), ast::StorageClass::kPrivate);
Global("param2", ty.f32(), ast::StorageClass::kPrivate);
auto* call = Call("my_func", "param1", "param2");
WrapInFunction(call);
@ -61,8 +61,8 @@ TEST_F(MslGeneratorImplTest, EmitStatement_Call) {
Param(Sym(), ty.f32()),
},
ty.void_(), ast::StatementList{}, ast::DecorationList{});
Global("param1", ty.f32(), ast::StorageClass::kInput);
Global("param2", ty.f32(), ast::StorageClass::kInput);
Global("param1", ty.f32(), ast::StorageClass::kPrivate);
Global("param2", ty.f32(), ast::StorageClass::kPrivate);
auto* call = Call("my_func", "param1", "param2");
auto* stmt = create<ast::CallStatement>(call);

View File

@ -1,271 +0,0 @@
// Copyright 2020 The Tint Authors.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include "src/ast/stage_decoration.h"
#include "src/writer/msl/test_helper.h"
namespace tint {
namespace writer {
namespace msl {
namespace {
using MslGeneratorImplTest = TestHelper;
TEST_F(MslGeneratorImplTest, Emit_Function_EntryPointData_Vertex_Input) {
// [[location 0]] var<in> foo : f32;
// [[location 1]] var<in> bar : i32;
//
// struct vtx_main_in {
// float foo [[attribute(0)]];
// int bar [[attribute(1)]];
// };
Global("foo", ty.f32(), ast::StorageClass::kInput, nullptr,
ast::DecorationList{Location(0)});
Global("bar", ty.i32(), ast::StorageClass::kInput, nullptr,
ast::DecorationList{Location(1)});
auto body = ast::StatementList{
Assign("foo", "foo"),
Assign("bar", "bar"),
Return(Construct(ty.vec4<f32>())),
};
Func("vtx_main", ast::VariableList{}, ty.vec4<f32>(), body,
{Stage(ast::PipelineStage::kVertex)},
{Builtin(ast::Builtin::kPosition)});
GeneratorImpl& gen = Build();
auto* func = program->AST().Functions()[0];
ASSERT_TRUE(gen.EmitEntryPointData(func)) << gen.error();
EXPECT_EQ(gen.result(), R"(struct vtx_main_in {
float foo [[attribute(0)]];
int bar [[attribute(1)]];
};
)");
}
TEST_F(MslGeneratorImplTest, Emit_Function_EntryPointData_Vertex_Output) {
// [[location 0]] var<out> foo : f32;
// [[location 1]] var<out> bar : i32;
//
// struct vtx_main_out {
// float foo [[user(locn0)]];
// int bar [[user(locn1)]];
// };
Global("foo", ty.f32(), ast::StorageClass::kOutput, nullptr,
ast::DecorationList{Location(0)});
Global("bar", ty.i32(), ast::StorageClass::kOutput, nullptr,
ast::DecorationList{Location(1)});
auto body = ast::StatementList{
Assign("foo", "foo"),
Assign("bar", "bar"),
Return(Construct(ty.vec4<f32>())),
};
Func("vtx_main", ast::VariableList{}, ty.vec4<f32>(), body,
{Stage(ast::PipelineStage::kVertex)},
{Builtin(ast::Builtin::kPosition)});
GeneratorImpl& gen = Build();
auto* func = program->AST().Functions()[0];
ASSERT_TRUE(gen.EmitEntryPointData(func)) << gen.error();
EXPECT_EQ(gen.result(), R"(struct vtx_main_out {
float foo [[user(locn0)]];
int bar [[user(locn1)]];
};
)");
}
TEST_F(MslGeneratorImplTest, Emit_Function_EntryPointData_Fragment_Input) {
// [[location 0]] var<in> foo : f32;
// [[location 1]] var<in> bar : i32;
//
// struct frag_main_in {
// float foo [[user(locn0)]];
// int bar [[user(locn1)]];
// };
Global("foo", ty.f32(), ast::StorageClass::kInput, nullptr,
ast::DecorationList{Location(0)});
Global("bar", ty.i32(), ast::StorageClass::kInput, nullptr,
ast::DecorationList{Location(1)});
auto body = ast::StatementList{
Assign("foo", "foo"),
Assign("bar", "bar"),
};
Func("main", ast::VariableList{}, ty.void_(), body,
ast::DecorationList{
Stage(ast::PipelineStage::kFragment),
});
GeneratorImpl& gen = Build();
auto* func = program->AST().Functions()[0];
ASSERT_TRUE(gen.EmitEntryPointData(func)) << gen.error();
EXPECT_EQ(gen.result(), R"(struct main_in {
float foo [[user(locn0)]];
int bar [[user(locn1)]];
};
)");
}
TEST_F(MslGeneratorImplTest, Emit_Function_EntryPointData_Fragment_Output) {
// [[location 0]] var<out> foo : f32;
// [[location 1]] var<out> bar : i32;
//
// struct frag_main_out {
// float foo [[color(0)]];
// int bar [[color(1)]];
// };
Global("foo", ty.f32(), ast::StorageClass::kOutput, nullptr,
ast::DecorationList{Location(0)});
Global("bar", ty.i32(), ast::StorageClass::kOutput, nullptr,
ast::DecorationList{Location(1)});
auto body = ast::StatementList{
Assign("foo", "foo"),
Assign("bar", "bar"),
};
Func("main", ast::VariableList{}, ty.void_(), body,
ast::DecorationList{
Stage(ast::PipelineStage::kFragment),
});
GeneratorImpl& gen = Build();
auto* func = program->AST().Functions()[0];
ASSERT_TRUE(gen.EmitEntryPointData(func)) << gen.error();
EXPECT_EQ(gen.result(), R"(struct main_out {
float foo [[color(0)]];
int bar [[color(1)]];
};
)");
}
TEST_F(MslGeneratorImplTest, Emit_Function_EntryPointData_Compute_Input) {
// [[location 0]] var<in> foo : f32;
// [[location 1]] var<in> bar : i32;
//
// -> Error, not allowed
Global("foo", ty.f32(), ast::StorageClass::kInput, nullptr,
ast::DecorationList{Location(0)});
Global("bar", ty.i32(), ast::StorageClass::kInput, nullptr,
ast::DecorationList{Location(1)});
auto body = ast::StatementList{
Assign("foo", "foo"),
Assign("bar", "bar"),
};
Func("main", ast::VariableList{}, ty.void_(), body,
ast::DecorationList{
Stage(ast::PipelineStage::kCompute),
});
GeneratorImpl& gen = Build();
auto* func = program->AST().Functions()[0];
ASSERT_FALSE(gen.EmitEntryPointData(func)) << gen.error();
EXPECT_EQ(gen.error(),
R"(error: invalid location variable for pipeline stage)");
}
TEST_F(MslGeneratorImplTest, Emit_Function_EntryPointData_Compute_Output) {
// [[location 0]] var<out> foo : f32;
// [[location 1]] var<out> bar : i32;
//
// -> Error not allowed
Global("foo", ty.f32(), ast::StorageClass::kOutput, nullptr,
ast::DecorationList{Location(0)});
Global("bar", ty.i32(), ast::StorageClass::kOutput, nullptr,
ast::DecorationList{Location(1)});
auto body = ast::StatementList{
Assign("foo", "foo"),
Assign("bar", "bar"),
};
Func("main", ast::VariableList{}, ty.void_(), body,
ast::DecorationList{
Stage(ast::PipelineStage::kCompute),
});
GeneratorImpl& gen = Build();
auto* func = program->AST().Functions()[0];
ASSERT_FALSE(gen.EmitEntryPointData(func)) << gen.error();
EXPECT_EQ(gen.error(),
R"(error: invalid location variable for pipeline stage)");
}
TEST_F(MslGeneratorImplTest, Emit_Function_EntryPointData_Builtins) {
// Output builtins go in the output struct, input builtins will be passed
// as input parameters to the entry point function.
// [[builtin position]] var<in> coord : vec4<f32>;
// [[builtin frag_depth]] var<out> depth : f32;
//
// struct main_out {
// float depth [[depth(any)]];
// };
Global("coord", ty.vec4<f32>(), ast::StorageClass::kInput, nullptr,
ast::DecorationList{Builtin(ast::Builtin::kPosition)});
Global("depth", ty.f32(), ast::StorageClass::kOutput, nullptr,
ast::DecorationList{Builtin(ast::Builtin::kFragDepth)});
auto body = ast::StatementList{Assign("depth", MemberAccessor("coord", "x"))};
Func("main", ast::VariableList{}, ty.void_(), body,
ast::DecorationList{
Stage(ast::PipelineStage::kFragment),
});
GeneratorImpl& gen = Build();
auto* func = program->AST().Functions()[0];
ASSERT_TRUE(gen.EmitEntryPointData(func)) << gen.error();
EXPECT_EQ(gen.result(), R"(struct main_out {
float depth [[depth(any)]];
};
)");
}
} // namespace
} // namespace msl
} // namespace writer
} // namespace tint

View File

@ -382,171 +382,6 @@ fragment void frag_main(const device Data& coord [[buffer(0)]]) {
)");
}
// TODO(crbug.com/tint/697): Remove this test
TEST_F(
MslGeneratorImplTest,
Emit_Decoration_Called_By_EntryPoints_WithLocationGlobals_And_Params) { // NOLINT
Global("foo", ty.f32(), ast::StorageClass::kInput,
ast::DecorationList{Location(0)});
Global("bar", ty.f32(), ast::StorageClass::kOutput,
ast::DecorationList{
Location(1),
});
Global("val", ty.f32(), ast::StorageClass::kOutput,
ast::DecorationList{Location(0)});
ast::VariableList params;
params.push_back(Param("param", ty.f32()));
auto body = ast::StatementList{Assign("bar", "foo"), Assign("val", "param"),
Return("foo")};
Func("sub_func", params, ty.f32(), body, {});
body = ast::StatementList{
Assign("bar", Call("sub_func", 1.0f)),
Return(),
};
Func("ep_1", ast::VariableList{}, ty.void_(), body,
{
Stage(ast::PipelineStage::kFragment),
});
GeneratorImpl& gen = Build();
ASSERT_TRUE(gen.Generate()) << gen.error();
EXPECT_EQ(gen.result(), R"(#include <metal_stdlib>
using namespace metal;
struct ep_1_in {
float foo [[user(locn0)]];
};
struct ep_1_out {
float bar [[color(1)]];
float val [[color(0)]];
};
float sub_func_ep_1(thread ep_1_in& _tint_in, thread ep_1_out& _tint_out, float param) {
_tint_out.bar = _tint_in.foo;
_tint_out.val = param;
return _tint_in.foo;
}
fragment ep_1_out ep_1(ep_1_in _tint_in [[stage_in]]) {
ep_1_out _tint_out = {};
_tint_out.bar = sub_func_ep_1(_tint_in, _tint_out, 1.0f);
return _tint_out;
}
)");
}
// TODO(crbug.com/tint/697): Remove this test
TEST_F(MslGeneratorImplTest,
Emit_Decoration_Called_By_EntryPoints_NoUsedGlobals) {
Global("depth", ty.f32(), ast::StorageClass::kOutput,
ast::DecorationList{Builtin(ast::Builtin::kFragDepth)});
ast::VariableList params;
params.push_back(Param("param", ty.f32()));
Func("sub_func", params, ty.f32(),
ast::StatementList{
Return("param"),
},
{});
auto body = ast::StatementList{
Assign("depth", Call("sub_func", 1.0f)),
Return(),
};
Func("ep_1", ast::VariableList{}, ty.void_(), body,
{
Stage(ast::PipelineStage::kFragment),
});
GeneratorImpl& gen = Build();
ASSERT_TRUE(gen.Generate()) << gen.error();
EXPECT_EQ(gen.result(), R"(#include <metal_stdlib>
using namespace metal;
struct ep_1_out {
float depth [[depth(any)]];
};
float sub_func(float param) {
return param;
}
fragment ep_1_out ep_1() {
ep_1_out _tint_out = {};
_tint_out.depth = sub_func(1.0f);
return _tint_out;
}
)");
}
// TODO(crbug.com/tint/697): Remove this test
TEST_F(
MslGeneratorImplTest,
Emit_Decoration_Called_By_EntryPoints_WithBuiltinGlobals_And_Params) { // NOLINT
Global("coord", ty.vec4<f32>(), ast::StorageClass::kInput,
ast::DecorationList{Builtin(ast::Builtin::kPosition)});
Global("depth", ty.f32(), ast::StorageClass::kOutput,
ast::DecorationList{Builtin(ast::Builtin::kFragDepth)});
ast::VariableList params;
params.push_back(Param("param", ty.f32()));
auto body = ast::StatementList{
Assign("depth", MemberAccessor("coord", "x")),
Return("param"),
};
Func("sub_func", params, ty.f32(), body, {});
body = ast::StatementList{
Assign("depth", Call("sub_func", 1.0f)),
Return(),
};
Func("ep_1", ast::VariableList{}, ty.void_(), body,
{
Stage(ast::PipelineStage::kFragment),
});
GeneratorImpl& gen = Build();
ASSERT_TRUE(gen.Generate()) << gen.error();
EXPECT_EQ(gen.result(), R"(#include <metal_stdlib>
using namespace metal;
struct ep_1_out {
float depth [[depth(any)]];
};
float sub_func_ep_1(thread ep_1_out& _tint_out, thread float4& coord, float param) {
_tint_out.depth = coord.x;
return param;
}
fragment ep_1_out ep_1(float4 coord [[position]]) {
ep_1_out _tint_out = {};
_tint_out.depth = sub_func_ep_1(_tint_out, coord, 1.0f);
return _tint_out;
}
)");
}
TEST_F(MslGeneratorImplTest,
Emit_Decoration_Called_By_EntryPoint_With_Uniform) {
auto* ubo_ty = Structure("UBO", {Member("coord", ty.vec4<f32>())},
@ -715,51 +550,6 @@ fragment void frag_main(const device Data& coord [[buffer(0)]]) {
)");
}
// TODO(crbug.com/tint/697): Remove this test
TEST_F(MslGeneratorImplTest,
Emit_Decoration_EntryPoints_WithGlobal_Nested_Return) {
Global("bar", ty.f32(), ast::StorageClass::kOutput,
ast::DecorationList{
Location(1),
});
auto* list = Block(Return());
auto body = ast::StatementList{
Assign("bar", Expr(1.f)),
create<ast::IfStatement>(create<ast::BinaryExpression>(
ast::BinaryOp::kEqual, Expr(1), Expr(1)),
list, ast::ElseStatementList{}),
Return(),
};
Func("ep_1", ast::VariableList{}, ty.void_(), body,
{
Stage(ast::PipelineStage::kFragment),
});
GeneratorImpl& gen = Build();
ASSERT_TRUE(gen.Generate()) << gen.error();
EXPECT_EQ(gen.result(), R"(#include <metal_stdlib>
using namespace metal;
struct ep_1_out {
float bar [[color(1)]];
};
fragment ep_1_out ep_1() {
ep_1_out _tint_out = {};
_tint_out.bar = 1.0f;
if ((1 == 1)) {
return _tint_out;
}
return _tint_out;
}
)");
}
TEST_F(MslGeneratorImplTest, Emit_Function_WithArrayParams) {
ast::VariableList params;
params.push_back(Param("a", ty.array<f32, 5>()));

View File

@ -65,8 +65,8 @@ TEST_F(MslGeneratorImplTest, Emit_LoopWithContinuing) {
}
TEST_F(MslGeneratorImplTest, Emit_LoopNestedWithContinuing) {
Global("lhs", ty.f32(), ast::StorageClass::kInput);
Global("rhs", ty.f32(), ast::StorageClass::kInput);
Global("lhs", ty.f32(), ast::StorageClass::kPrivate);
Global("rhs", ty.f32(), ast::StorageClass::kPrivate);
auto* body = Block(create<ast::DiscardStatement>());
auto* continuing = Block(Return());
@ -130,7 +130,7 @@ TEST_F(MslGeneratorImplTest, Emit_LoopWithVarUsedInContinuing) {
// }
// }
Global("rhs", ty.f32(), ast::StorageClass::kInput);
Global("rhs", ty.f32(), ast::StorageClass::kPrivate);
auto* var = Var("lhs", ty.f32(), ast::StorageClass::kNone, Expr(2.4f));