// Copyright 2020 The Tint Authors. // // Licensed under the Apache License, Version 2.0 (the "License"); // you may not use this file except in compliance with the License. // You may obtain a copy of the License at // // http://www.apache.org/licenses/LICENSE-2.0 // // Unless required by applicable law or agreed to in writing, software // distributed under the License is distributed on an "AS IS" BASIS, // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. // See the License for the specific language governing permissions and // limitations under the License. #include "src/writer/msl/generator_impl.h" #include #include #include #include #include "src/ast/bool_literal.h" #include "src/ast/call_statement.h" #include "src/ast/constant_id_decoration.h" #include "src/ast/fallthrough_statement.h" #include "src/ast/float_literal.h" #include "src/ast/module.h" #include "src/ast/sint_literal.h" #include "src/ast/uint_literal.h" #include "src/ast/variable_decl_statement.h" #include "src/semantic/array.h" #include "src/semantic/call.h" #include "src/semantic/function.h" #include "src/semantic/member_accessor_expression.h" #include "src/semantic/struct.h" #include "src/semantic/variable.h" #include "src/type/access_control_type.h" #include "src/type/alias_type.h" #include "src/type/array_type.h" #include "src/type/bool_type.h" #include "src/type/depth_texture_type.h" #include "src/type/f32_type.h" #include "src/type/i32_type.h" #include "src/type/matrix_type.h" #include "src/type/multisampled_texture_type.h" #include "src/type/pointer_type.h" #include "src/type/sampled_texture_type.h" #include "src/type/storage_texture_type.h" #include "src/type/u32_type.h" #include "src/type/vector_type.h" #include "src/type/void_type.h" #include "src/writer/float_to_string.h" namespace tint { namespace writer { namespace msl { namespace { const char kInStructNameSuffix[] = "in"; const char kOutStructNameSuffix[] = "out"; const char kTintStructInVarPrefix[] = "_tint_in"; const char kTintStructOutVarPrefix[] = "_tint_out"; bool last_is_break_or_fallthrough(const ast::BlockStatement* stmts) { if (stmts->empty()) { return false; } return stmts->last()->Is() || stmts->last()->Is(); } } // namespace GeneratorImpl::GeneratorImpl(const Program* program) : TextGenerator(), program_(program) {} GeneratorImpl::~GeneratorImpl() = default; bool GeneratorImpl::Generate() { out_ << "#include " << std::endl << std::endl; out_ << "using namespace metal;" << std::endl; for (auto* global : program_->AST().GlobalVariables()) { auto* sem = program_->Sem().Get(global); global_variables_.set(global->symbol(), sem); } for (auto* const ty : program_->AST().ConstructedTypes()) { if (!EmitConstructedType(ty)) { return false; } } if (!program_->AST().ConstructedTypes().empty()) { out_ << std::endl; } for (auto* var : program_->AST().GlobalVariables()) { if (!var->is_const()) { continue; } if (!EmitProgramConstVariable(var)) { return false; } } // Make sure all entry point data is emitted before the entry point functions for (auto* func : program_->AST().Functions()) { if (!func->IsEntryPoint()) { continue; } if (!EmitEntryPointData(func)) { return false; } } for (auto* func : program_->AST().Functions()) { if (!EmitFunction(func)) { return false; } } for (auto* func : program_->AST().Functions()) { if (!func->IsEntryPoint()) { continue; } if (!EmitEntryPointFunction(func)) { return false; } out_ << std::endl; } return true; } bool GeneratorImpl::EmitConstructedType(const type::Type* ty) { make_indent(); if (auto* alias = ty->As()) { out_ << "typedef "; if (!EmitType(alias->type(), "")) { return false; } out_ << " " << program_->Symbols().NameFor(alias->symbol()) << ";" << std::endl; } else if (auto* str = ty->As()) { if (!EmitStructType(str)) { return false; } } else { diagnostics_.add_error("unknown alias type: " + ty->type_name()); return false; } return true; } bool GeneratorImpl::EmitArrayAccessor(ast::ArrayAccessorExpression* expr) { if (!EmitExpression(expr->array())) { return false; } out_ << "["; if (!EmitExpression(expr->idx_expr())) { return false; } out_ << "]"; return true; } bool GeneratorImpl::EmitBitcast(ast::BitcastExpression* expr) { out_ << "as_type<"; if (!EmitType(expr->type(), "")) { return false; } out_ << ">("; if (!EmitExpression(expr->expr())) { return false; } out_ << ")"; return true; } bool GeneratorImpl::EmitAssign(ast::AssignmentStatement* stmt) { make_indent(); if (!EmitExpression(stmt->lhs())) { return false; } out_ << " = "; if (!EmitExpression(stmt->rhs())) { return false; } out_ << ";" << std::endl; return true; } bool GeneratorImpl::EmitBinary(ast::BinaryExpression* expr) { out_ << "("; if (!EmitExpression(expr->lhs())) { return false; } out_ << " "; switch (expr->op()) { case ast::BinaryOp::kAnd: out_ << "&"; break; case ast::BinaryOp::kOr: out_ << "|"; break; case ast::BinaryOp::kXor: out_ << "^"; break; case ast::BinaryOp::kLogicalAnd: out_ << "&&"; break; case ast::BinaryOp::kLogicalOr: out_ << "||"; break; case ast::BinaryOp::kEqual: out_ << "=="; break; case ast::BinaryOp::kNotEqual: out_ << "!="; break; case ast::BinaryOp::kLessThan: out_ << "<"; break; case ast::BinaryOp::kGreaterThan: out_ << ">"; break; case ast::BinaryOp::kLessThanEqual: out_ << "<="; break; case ast::BinaryOp::kGreaterThanEqual: out_ << ">="; break; case ast::BinaryOp::kShiftLeft: out_ << "<<"; break; case ast::BinaryOp::kShiftRight: // TODO(dsinclair): MSL is based on C++14, and >> in C++14 has // implementation-defined behaviour for negative LHS. We may have to // generate extra code to implement WGSL-specified behaviour for negative // LHS. out_ << R"(>>)"; break; case ast::BinaryOp::kAdd: out_ << "+"; break; case ast::BinaryOp::kSubtract: out_ << "-"; break; case ast::BinaryOp::kMultiply: out_ << "*"; break; case ast::BinaryOp::kDivide: out_ << "/"; break; case ast::BinaryOp::kModulo: out_ << "%"; break; case ast::BinaryOp::kNone: diagnostics_.add_error("missing binary operation type"); return false; } out_ << " "; if (!EmitExpression(expr->rhs())) { return false; } out_ << ")"; return true; } bool GeneratorImpl::EmitBreak(ast::BreakStatement*) { make_indent(); out_ << "break;" << std::endl; return true; } std::string GeneratorImpl::current_ep_var_name(VarType type) { std::string name = ""; switch (type) { case VarType::kIn: { auto in_it = ep_sym_to_in_data_.find(current_ep_sym_); if (in_it != ep_sym_to_in_data_.end()) { name = in_it->second.var_name; } break; } case VarType::kOut: { auto out_it = ep_sym_to_out_data_.find(current_ep_sym_); if (out_it != ep_sym_to_out_data_.end()) { name = out_it->second.var_name; } break; } } return name; } bool GeneratorImpl::EmitCall(ast::CallExpression* expr) { auto* ident = expr->func()->As(); if (ident == nullptr) { diagnostics_.add_error("invalid function name"); return 0; } auto* call = program_->Sem().Get(expr); if (auto* intrinsic = call->Target()->As()) { if (intrinsic->IsTexture()) { return EmitTextureCall(expr, intrinsic); } if (intrinsic->Type() == semantic::IntrinsicType::kPack2x16Float || intrinsic->Type() == semantic::IntrinsicType::kUnpack2x16Float) { make_indent(); if (intrinsic->Type() == semantic::IntrinsicType::kPack2x16Float) { out_ << "as_type(half2("; } else { out_ << "float2(as_type("; } if (!EmitExpression(expr->params()[0])) { return false; } out_ << "))"; return true; } // TODO(crbug.com/tint/661): Combine sequential barriers to a single // instruction. if (intrinsic->Type() == semantic::IntrinsicType::kStorageBarrier) { make_indent(); out_ << "threadgroup_barrier(mem_flags::mem_device)"; return true; } if (intrinsic->Type() == semantic::IntrinsicType::kWorkgroupBarrier) { make_indent(); out_ << "threadgroup_barrier(mem_flags::mem_threadgroup)"; return true; } auto name = generate_builtin_name(intrinsic); if (name.empty()) { return false; } make_indent(); out_ << name << "("; bool first = true; const auto& params = expr->params(); for (auto* param : params) { if (!first) { out_ << ", "; } first = false; if (!EmitExpression(param)) { return false; } } out_ << ")"; return true; } auto name = program_->Symbols().NameFor(ident->symbol()); auto caller_sym = ident->symbol(); auto it = ep_func_name_remapped_.find(current_ep_sym_.to_str() + "_" + caller_sym.to_str()); if (it != ep_func_name_remapped_.end()) { name = it->second; } auto* func = program_->AST().Functions().Find(ident->symbol()); if (func == nullptr) { diagnostics_.add_error("Unable to find function: " + program_->Symbols().NameFor(ident->symbol())); return false; } out_ << name << "("; bool first = true; if (has_referenced_in_var_needing_struct(func)) { auto var_name = current_ep_var_name(VarType::kIn); if (!var_name.empty()) { out_ << var_name; first = false; } } if (has_referenced_out_var_needing_struct(func)) { auto var_name = current_ep_var_name(VarType::kOut); if (!var_name.empty()) { if (!first) { out_ << ", "; } first = false; out_ << var_name; } } auto* func_sem = program_->Sem().Get(func); for (const auto& data : func_sem->ReferencedBuiltinVariables()) { auto* var = data.first; if (var->StorageClass() != ast::StorageClass::kInput) { continue; } if (!first) { out_ << ", "; } first = false; out_ << program_->Symbols().NameFor(var->Declaration()->symbol()); } for (const auto& data : func_sem->ReferencedUniformVariables()) { auto* var = data.first; if (!first) { out_ << ", "; } first = false; out_ << program_->Symbols().NameFor(var->Declaration()->symbol()); } for (const auto& data : func_sem->ReferencedStorageBufferVariables()) { auto* var = data.first; if (!first) { out_ << ", "; } first = false; out_ << program_->Symbols().NameFor(var->Declaration()->symbol()); } const auto& params = expr->params(); for (auto* param : params) { if (!first) { out_ << ", "; } first = false; if (!EmitExpression(param)) { return false; } } out_ << ")"; return true; } bool GeneratorImpl::EmitTextureCall(ast::CallExpression* expr, const semantic::Intrinsic* intrinsic) { using Usage = semantic::Parameter::Usage; auto parameters = intrinsic->Parameters(); auto arguments = expr->params(); // Returns the argument with the given usage auto arg = [&](Usage usage) { int idx = semantic::IndexOf(parameters, usage); return (idx >= 0) ? arguments[idx] : nullptr; }; auto* texture = arg(Usage::kTexture); assert(texture); auto* texture_type = TypeOf(texture)->UnwrapAll()->As(); switch (intrinsic->Type()) { case semantic::IntrinsicType::kTextureDimensions: { std::vector dims; switch (texture_type->dim()) { case type::TextureDimension::kNone: diagnostics_.add_error("texture dimension is kNone"); return false; case type::TextureDimension::k1d: dims = {"width"}; break; case type::TextureDimension::k2d: case type::TextureDimension::k2dArray: dims = {"width", "height"}; break; case type::TextureDimension::k3d: dims = {"width", "height", "depth"}; break; case type::TextureDimension::kCube: case type::TextureDimension::kCubeArray: // width == height == depth for cubes // See https://github.com/gpuweb/gpuweb/issues/1345 dims = {"width", "height", "height"}; break; } auto get_dim = [&](const char* name) { if (!EmitExpression(texture)) { return false; } out_ << ".get_" << name << "("; if (auto* level = arg(Usage::kLevel)) { if (!EmitExpression(level)) { return false; } } out_ << ")"; return true; }; if (dims.size() == 1) { out_ << "int("; get_dim(dims[0]); out_ << ")"; } else { EmitType(TypeOf(expr), ""); out_ << "("; for (size_t i = 0; i < dims.size(); i++) { if (i > 0) { out_ << ", "; } get_dim(dims[i]); } out_ << ")"; } return true; } case semantic::IntrinsicType::kTextureNumLayers: { out_ << "int("; if (!EmitExpression(texture)) { return false; } out_ << ".get_array_size())"; return true; } case semantic::IntrinsicType::kTextureNumLevels: { out_ << "int("; if (!EmitExpression(texture)) { return false; } out_ << ".get_num_mip_levels())"; return true; } case semantic::IntrinsicType::kTextureNumSamples: { out_ << "int("; if (!EmitExpression(texture)) { return false; } out_ << ".get_num_samples())"; return true; } default: break; } if (!EmitExpression(texture)) return false; bool lod_param_is_named = true; switch (intrinsic->Type()) { case semantic::IntrinsicType::kTextureSample: case semantic::IntrinsicType::kTextureSampleBias: case semantic::IntrinsicType::kTextureSampleLevel: case semantic::IntrinsicType::kTextureSampleGrad: out_ << ".sample("; break; case semantic::IntrinsicType::kTextureSampleCompare: out_ << ".sample_compare("; break; case semantic::IntrinsicType::kTextureLoad: out_ << ".read("; lod_param_is_named = false; break; case semantic::IntrinsicType::kTextureStore: out_ << ".write("; break; default: TINT_UNREACHABLE(diagnostics_) << "Unhandled texture intrinsic '" << intrinsic->str() << "'"; return false; } bool first_arg = true; auto maybe_write_comma = [&] { if (!first_arg) { out_ << ", "; } first_arg = false; }; for (auto usage : {Usage::kValue, Usage::kSampler, Usage::kCoords, Usage::kArrayIndex, Usage::kDepthRef, Usage::kSampleIndex}) { if (auto* e = arg(usage)) { maybe_write_comma(); if (!EmitExpression(e)) return false; } } if (auto* bias = arg(Usage::kBias)) { maybe_write_comma(); out_ << "bias("; if (!EmitExpression(bias)) { return false; } out_ << ")"; } if (auto* level = arg(Usage::kLevel)) { maybe_write_comma(); if (lod_param_is_named) { out_ << "level("; } if (!EmitExpression(level)) { return false; } if (lod_param_is_named) { out_ << ")"; } } if (auto* ddx = arg(Usage::kDdx)) { auto dim = texture_type->dim(); switch (dim) { case type::TextureDimension::k2d: case type::TextureDimension::k2dArray: maybe_write_comma(); out_ << "gradient2d("; break; case type::TextureDimension::k3d: maybe_write_comma(); out_ << "gradient3d("; break; case type::TextureDimension::kCube: case type::TextureDimension::kCubeArray: maybe_write_comma(); out_ << "gradientcube("; break; default: { std::stringstream err; err << "MSL does not support gradients for " << dim << " textures"; diagnostics_.add_error(err.str()); return false; } } if (!EmitExpression(ddx)) { return false; } out_ << ", "; if (!EmitExpression(arg(Usage::kDdy))) { return false; } out_ << ")"; } if (auto* offset = arg(Usage::kOffset)) { maybe_write_comma(); if (!EmitExpression(offset)) { return false; } } out_ << ")"; return true; } std::string GeneratorImpl::generate_builtin_name( const semantic::Intrinsic* intrinsic) { std::string out = ""; switch (intrinsic->Type()) { case semantic::IntrinsicType::kAcos: case semantic::IntrinsicType::kAll: case semantic::IntrinsicType::kAny: case semantic::IntrinsicType::kAsin: case semantic::IntrinsicType::kAtan: case semantic::IntrinsicType::kAtan2: case semantic::IntrinsicType::kCeil: case semantic::IntrinsicType::kCos: case semantic::IntrinsicType::kCosh: case semantic::IntrinsicType::kCross: case semantic::IntrinsicType::kDeterminant: case semantic::IntrinsicType::kDistance: case semantic::IntrinsicType::kDot: case semantic::IntrinsicType::kExp: case semantic::IntrinsicType::kExp2: case semantic::IntrinsicType::kFloor: case semantic::IntrinsicType::kFma: case semantic::IntrinsicType::kFract: case semantic::IntrinsicType::kLength: case semantic::IntrinsicType::kLdexp: case semantic::IntrinsicType::kLog: case semantic::IntrinsicType::kLog2: case semantic::IntrinsicType::kMix: case semantic::IntrinsicType::kNormalize: case semantic::IntrinsicType::kPow: case semantic::IntrinsicType::kReflect: case semantic::IntrinsicType::kSelect: case semantic::IntrinsicType::kSin: case semantic::IntrinsicType::kSinh: case semantic::IntrinsicType::kSqrt: case semantic::IntrinsicType::kStep: case semantic::IntrinsicType::kTan: case semantic::IntrinsicType::kTanh: case semantic::IntrinsicType::kTrunc: case semantic::IntrinsicType::kSign: case semantic::IntrinsicType::kClamp: out += intrinsic->str(); break; case semantic::IntrinsicType::kAbs: if (intrinsic->ReturnType()->is_float_scalar_or_vector()) { out += "fabs"; } else { out += "abs"; } break; case semantic::IntrinsicType::kCountOneBits: out += "popcount"; break; case semantic::IntrinsicType::kDpdx: case semantic::IntrinsicType::kDpdxCoarse: case semantic::IntrinsicType::kDpdxFine: out += "dfdx"; break; case semantic::IntrinsicType::kDpdy: case semantic::IntrinsicType::kDpdyCoarse: case semantic::IntrinsicType::kDpdyFine: out += "dfdy"; break; case semantic::IntrinsicType::kFwidth: case semantic::IntrinsicType::kFwidthCoarse: case semantic::IntrinsicType::kFwidthFine: out += "fwidth"; break; case semantic::IntrinsicType::kIsFinite: out += "isfinite"; break; case semantic::IntrinsicType::kIsInf: out += "isinf"; break; case semantic::IntrinsicType::kIsNan: out += "isnan"; break; case semantic::IntrinsicType::kIsNormal: out += "isnormal"; break; case semantic::IntrinsicType::kMax: if (intrinsic->ReturnType()->is_float_scalar_or_vector()) { out += "fmax"; } else { out += "max"; } break; case semantic::IntrinsicType::kMin: if (intrinsic->ReturnType()->is_float_scalar_or_vector()) { out += "fmin"; } else { out += "min"; } break; case semantic::IntrinsicType::kFaceForward: out += "faceforward"; break; case semantic::IntrinsicType::kPack4x8Snorm: out += "pack_float_to_snorm4x8"; break; case semantic::IntrinsicType::kPack4x8Unorm: out += "pack_float_to_unorm4x8"; break; case semantic::IntrinsicType::kPack2x16Snorm: out += "pack_float_to_snorm2x16"; break; case semantic::IntrinsicType::kPack2x16Unorm: out += "pack_float_to_unorm2x16"; break; case semantic::IntrinsicType::kReverseBits: out += "reverse_bits"; break; case semantic::IntrinsicType::kRound: out += "rint"; break; case semantic::IntrinsicType::kSmoothStep: out += "smoothstep"; break; case semantic::IntrinsicType::kInverseSqrt: out += "rsqrt"; break; case semantic::IntrinsicType::kUnpack4x8Snorm: out += "unpack_snorm4x8_to_float"; break; case semantic::IntrinsicType::kUnpack4x8Unorm: out += "unpack_unorm4x8_to_float"; break; case semantic::IntrinsicType::kUnpack2x16Snorm: out += "unpack_snorm2x16_to_float"; break; case semantic::IntrinsicType::kUnpack2x16Unorm: out += "unpack_unorm2x16_to_float"; break; default: diagnostics_.add_error("Unknown import method: " + std::string(intrinsic->str())); return ""; } return out; } bool GeneratorImpl::EmitCase(ast::CaseStatement* stmt) { make_indent(); if (stmt->IsDefault()) { out_ << "default:"; } else { bool first = true; for (auto* selector : stmt->selectors()) { if (!first) { out_ << std::endl; make_indent(); } first = false; out_ << "case "; if (!EmitLiteral(selector)) { return false; } out_ << ":"; } } out_ << " {" << std::endl; increment_indent(); for (auto* s : *stmt->body()) { if (!EmitStatement(s)) { return false; } } if (!last_is_break_or_fallthrough(stmt->body())) { make_indent(); out_ << "break;" << std::endl; } decrement_indent(); make_indent(); out_ << "}" << std::endl; return true; } bool GeneratorImpl::EmitConstructor(ast::ConstructorExpression* expr) { if (auto* scalar = expr->As()) { return EmitScalarConstructor(scalar); } return EmitTypeConstructor(expr->As()); } bool GeneratorImpl::EmitContinue(ast::ContinueStatement*) { make_indent(); out_ << "continue;" << std::endl; return true; } bool GeneratorImpl::EmitTypeConstructor(ast::TypeConstructorExpression* expr) { if (expr->type()->IsAnyOf()) { out_ << "{"; } else { if (!EmitType(expr->type(), "")) { return false; } out_ << "("; } // If the type constructor is empty then we need to construct with the zero // value for all components. if (expr->values().empty()) { if (!EmitZeroValue(expr->type())) { return false; } } else { bool first = true; for (auto* e : expr->values()) { if (!first) { out_ << ", "; } first = false; if (!EmitExpression(e)) { return false; } } } if (expr->type()->IsAnyOf()) { out_ << "}"; } else { out_ << ")"; } return true; } bool GeneratorImpl::EmitZeroValue(type::Type* type) { if (type->Is()) { out_ << "false"; } else if (type->Is()) { out_ << "0.0f"; } else if (type->Is()) { out_ << "0"; } else if (type->Is()) { out_ << "0u"; } else if (auto* vec = type->As()) { return EmitZeroValue(vec->type()); } else if (auto* mat = type->As()) { return EmitZeroValue(mat->type()); } else if (auto* arr = type->As()) { out_ << "{"; if (!EmitZeroValue(arr->type())) { return false; } out_ << "}"; } else if (type->As()) { out_ << "{}"; } else { diagnostics_.add_error("Invalid type for zero emission: " + type->type_name()); return false; } return true; } bool GeneratorImpl::EmitScalarConstructor( ast::ScalarConstructorExpression* expr) { return EmitLiteral(expr->literal()); } bool GeneratorImpl::EmitLiteral(ast::Literal* lit) { if (auto* l = lit->As()) { out_ << (l->IsTrue() ? "true" : "false"); } else if (auto* fl = lit->As()) { out_ << FloatToString(fl->value()) << "f"; } else if (auto* sl = lit->As()) { out_ << sl->value(); } else if (auto* ul = lit->As()) { out_ << ul->value() << "u"; } else { diagnostics_.add_error("unknown literal type"); return false; } return true; } // TODO(jrprice): Remove this when we remove support for entry point params as // module-scope globals. bool GeneratorImpl::EmitEntryPointData(ast::Function* func) { auto* func_sem = program_->Sem().Get(func); std::vector> in_locations; std::vector> out_variables; for (auto data : func_sem->ReferencedLocationVariables()) { auto* var = data.first; auto* deco = data.second; if (var->StorageClass() == ast::StorageClass::kInput) { in_locations.push_back({var->Declaration(), deco->value()}); } else if (var->StorageClass() == ast::StorageClass::kOutput) { out_variables.push_back({var->Declaration(), deco}); } } for (auto data : func_sem->ReferencedBuiltinVariables()) { auto* var = data.first; auto* deco = data.second; if (var->StorageClass() == ast::StorageClass::kOutput) { out_variables.push_back({var->Declaration(), deco}); } } if (!in_locations.empty()) { auto in_struct_name = program_->Symbols().NameFor(func->symbol()) + "_" + kInStructNameSuffix; auto* in_var_name = kTintStructInVarPrefix; ep_sym_to_in_data_[func->symbol()] = {in_struct_name, in_var_name}; make_indent(); out_ << "struct " << in_struct_name << " {" << std::endl; increment_indent(); for (auto& data : in_locations) { auto* var = data.first; uint32_t loc = data.second; make_indent(); if (!EmitType(program_->Sem().Get(var)->Type(), program_->Symbols().NameFor(var->symbol()))) { return false; } out_ << " " << program_->Symbols().NameFor(var->symbol()) << " [["; if (func->pipeline_stage() == ast::PipelineStage::kVertex) { out_ << "attribute(" << loc << ")"; } else if (func->pipeline_stage() == ast::PipelineStage::kFragment) { out_ << "user(locn" << loc << ")"; } else { diagnostics_.add_error("invalid location variable for pipeline stage"); return false; } out_ << "]];" << std::endl; } decrement_indent(); make_indent(); out_ << "};" << std::endl << std::endl; } if (!out_variables.empty()) { auto out_struct_name = program_->Symbols().NameFor(func->symbol()) + "_" + kOutStructNameSuffix; auto* out_var_name = kTintStructOutVarPrefix; ep_sym_to_out_data_[func->symbol()] = {out_struct_name, out_var_name}; make_indent(); out_ << "struct " << out_struct_name << " {" << std::endl; increment_indent(); for (auto& data : out_variables) { auto* var = data.first; auto* deco = data.second; make_indent(); if (!EmitType(program_->Sem().Get(var)->Type(), program_->Symbols().NameFor(var->symbol()))) { return false; } out_ << " " << program_->Symbols().NameFor(var->symbol()) << " [["; if (auto* location = deco->As()) { auto loc = location->value(); if (func->pipeline_stage() == ast::PipelineStage::kVertex) { out_ << "user(locn" << loc << ")"; } else if (func->pipeline_stage() == ast::PipelineStage::kFragment) { out_ << "color(" << loc << ")"; } else { diagnostics_.add_error( "invalid location variable for pipeline stage"); return false; } } else if (auto* builtin = deco->As()) { auto attr = builtin_to_attribute(builtin->value()); if (attr.empty()) { diagnostics_.add_error("unsupported builtin"); return false; } out_ << attr; } else { diagnostics_.add_error( "unsupported variable decoration for entry point output"); return false; } out_ << "]];" << std::endl; } decrement_indent(); make_indent(); out_ << "};" << std::endl << std::endl; } return true; } bool GeneratorImpl::EmitExpression(ast::Expression* expr) { if (auto* a = expr->As()) { return EmitArrayAccessor(a); } if (auto* b = expr->As()) { return EmitBinary(b); } if (auto* b = expr->As()) { return EmitBitcast(b); } if (auto* c = expr->As()) { return EmitCall(c); } if (auto* c = expr->As()) { return EmitConstructor(c); } if (auto* i = expr->As()) { return EmitIdentifier(i); } if (auto* m = expr->As()) { return EmitMemberAccessor(m); } if (auto* u = expr->As()) { return EmitUnaryOp(u); } diagnostics_.add_error("unknown expression type: " + program_->str(expr)); return false; } void GeneratorImpl::EmitStage(ast::PipelineStage stage) { switch (stage) { case ast::PipelineStage::kFragment: out_ << "fragment"; break; case ast::PipelineStage::kVertex: out_ << "vertex"; break; case ast::PipelineStage::kCompute: out_ << "kernel"; break; case ast::PipelineStage::kNone: break; } return; } bool GeneratorImpl::has_referenced_in_var_needing_struct(ast::Function* func) { auto* func_sem = program_->Sem().Get(func); for (auto data : func_sem->ReferencedLocationVariables()) { auto* var = data.first; if (var->StorageClass() == ast::StorageClass::kInput) { return true; } } return false; } bool GeneratorImpl::has_referenced_out_var_needing_struct(ast::Function* func) { auto* func_sem = program_->Sem().Get(func); for (auto data : func_sem->ReferencedLocationVariables()) { auto* var = data.first; if (var->StorageClass() == ast::StorageClass::kOutput) { return true; } } for (auto data : func_sem->ReferencedBuiltinVariables()) { auto* var = data.first; if (var->StorageClass() == ast::StorageClass::kOutput) { return true; } } return false; } bool GeneratorImpl::has_referenced_var_needing_struct(ast::Function* func) { return has_referenced_in_var_needing_struct(func) || has_referenced_out_var_needing_struct(func); } bool GeneratorImpl::EmitFunction(ast::Function* func) { auto* func_sem = program_->Sem().Get(func); make_indent(); // Entry points will be emitted later, skip for now. if (func->IsEntryPoint()) { return true; } // TODO(dsinclair): This could be smarter. If the input/outputs for multiple // entry points are the same we could generate a single struct and then have // this determine it's the same struct and just emit once. bool emit_duplicate_functions = func_sem->AncestorEntryPoints().size() > 0 && has_referenced_var_needing_struct(func); if (emit_duplicate_functions) { for (const auto& ep_sym : func_sem->AncestorEntryPoints()) { if (!EmitFunctionInternal(func, emit_duplicate_functions, ep_sym)) { return false; } out_ << std::endl; } } else { // Emit as non-duplicated if (!EmitFunctionInternal(func, false, Symbol())) { return false; } out_ << std::endl; } return true; } bool GeneratorImpl::EmitFunctionInternal(ast::Function* func, bool emit_duplicate_functions, Symbol ep_sym) { auto* func_sem = program_->Sem().Get(func); auto name = func->symbol().to_str(); if (!EmitType(func->return_type(), "")) { return false; } out_ << " "; if (emit_duplicate_functions) { auto func_name = name; auto ep_name = ep_sym.to_str(); name = program_->Symbols().NameFor(func->symbol()) + "_" + program_->Symbols().NameFor(ep_sym); ep_func_name_remapped_[ep_name + "_" + func_name] = name; } else { name = program_->Symbols().NameFor(func->symbol()); } out_ << name << "("; bool first = true; // If we're emitting duplicate functions that means the function takes // the stage_in or stage_out value from the entry point, emit them. // // We emit both of them if they're there regardless of if they're both used. if (emit_duplicate_functions) { auto in_it = ep_sym_to_in_data_.find(ep_sym); if (in_it != ep_sym_to_in_data_.end()) { out_ << "thread " << in_it->second.struct_name << "& " << in_it->second.var_name; first = false; } auto out_it = ep_sym_to_out_data_.find(ep_sym); if (out_it != ep_sym_to_out_data_.end()) { if (!first) { out_ << ", "; } out_ << "thread " << out_it->second.struct_name << "& " << out_it->second.var_name; first = false; } } for (const auto& data : func_sem->ReferencedBuiltinVariables()) { auto* var = data.first; if (var->StorageClass() != ast::StorageClass::kInput) { continue; } if (!first) { out_ << ", "; } first = false; out_ << "thread "; if (!EmitType(var->Type(), "")) { return false; } out_ << "& " << program_->Symbols().NameFor(var->Declaration()->symbol()); } for (const auto& data : func_sem->ReferencedUniformVariables()) { auto* var = data.first; if (!first) { out_ << ", "; } first = false; out_ << "constant "; // TODO(dsinclair): Can arrays be uniform? If so, fix this ... if (!EmitType(var->Type(), "")) { return false; } out_ << "& " << program_->Symbols().NameFor(var->Declaration()->symbol()); } for (const auto& data : func_sem->ReferencedStorageBufferVariables()) { auto* var = data.first; if (!first) { out_ << ", "; } first = false; auto* ac = var->Type()->As(); if (ac == nullptr) { diagnostics_.add_error( "invalid type for storage buffer, expected access control"); return false; } if (ac->IsReadOnly()) { out_ << "const "; } out_ << "device "; if (!EmitType(ac->type(), "")) { return false; } out_ << "& " << program_->Symbols().NameFor(var->Declaration()->symbol()); } for (auto* v : func->params()) { if (!first) { out_ << ", "; } first = false; auto* type = program_->Sem().Get(v)->Type(); if (!EmitType(type, program_->Symbols().NameFor(v->symbol()))) { return false; } // Array name is output as part of the type if (!type->Is()) { out_ << " " << program_->Symbols().NameFor(v->symbol()); } } out_ << ") "; current_ep_sym_ = ep_sym; if (!EmitBlockAndNewline(func->body())) { return false; } current_ep_sym_ = Symbol(); return true; } std::string GeneratorImpl::builtin_to_attribute(ast::Builtin builtin) const { switch (builtin) { case ast::Builtin::kPosition: return "position"; case ast::Builtin::kVertexIndex: return "vertex_id"; case ast::Builtin::kInstanceIndex: return "instance_id"; case ast::Builtin::kFrontFacing: return "front_facing"; case ast::Builtin::kFragCoord: return "position"; case ast::Builtin::kFragDepth: return "depth(any)"; case ast::Builtin::kLocalInvocationId: return "thread_position_in_threadgroup"; case ast::Builtin::kLocalInvocationIndex: return "thread_index_in_threadgroup"; case ast::Builtin::kGlobalInvocationId: return "thread_position_in_grid"; case ast::Builtin::kSampleIndex: return "sample_id"; case ast::Builtin::kSampleMaskIn: return "sample_mask"; case ast::Builtin::kSampleMaskOut: return "sample_mask"; default: break; } return ""; } bool GeneratorImpl::EmitEntryPointFunction(ast::Function* func) { auto* func_sem = program_->Sem().Get(func); make_indent(); current_ep_sym_ = func->symbol(); EmitStage(func->pipeline_stage()); out_ << " "; // This is an entry point, the return type is the entry point output structure // if one exists, or void otherwise. auto out_data = ep_sym_to_out_data_.find(current_ep_sym_); bool has_out_data = out_data != ep_sym_to_out_data_.end(); if (has_out_data) { out_ << out_data->second.struct_name; } else { out_ << "void"; } out_ << " " << program_->Symbols().NameFor(func->symbol()) << "("; bool first = true; // TODO(jrprice): Remove this when we remove support for builtins as globals. auto in_data = ep_sym_to_in_data_.find(current_ep_sym_); if (in_data != ep_sym_to_in_data_.end()) { out_ << in_data->second.struct_name << " " << in_data->second.var_name << " [[stage_in]]"; first = false; } // Emit entry point parameters. for (auto* var : func->params()) { if (!first) { out_ << ", "; } first = false; auto* type = program_->Sem().Get(var)->Type(); if (!EmitType(type, "")) { return false; } out_ << " " << program_->Symbols().NameFor(var->symbol()); if (type->Is()) { out_ << " [[stage_in]]"; } else { auto& decos = var->decorations(); bool builtin_found = false; for (auto* deco : decos) { auto* builtin = deco->As(); if (!builtin) { continue; } builtin_found = true; auto attr = builtin_to_attribute(builtin->value()); if (attr.empty()) { diagnostics_.add_error("unknown builtin"); return false; } out_ << " [[" << attr << "]]"; } if (!builtin_found) { TINT_ICE(diagnostics_) << "Unsupported entry point parameter"; } } } // TODO(jrprice): Remove this when we remove support for builtins as globals. for (auto data : func_sem->ReferencedBuiltinVariables()) { auto* var = data.first; if (var->StorageClass() != ast::StorageClass::kInput) { continue; } if (!first) { out_ << ", "; } first = false; auto* builtin = data.second; if (!EmitType(var->Type(), "")) { return false; } auto attr = builtin_to_attribute(builtin->value()); if (attr.empty()) { diagnostics_.add_error("unknown builtin"); return false; } out_ << " " << program_->Symbols().NameFor(var->Declaration()->symbol()) << " [[" << attr << "]]"; } for (auto data : func_sem->ReferencedUniformVariables()) { if (!first) { out_ << ", "; } first = false; auto* var = data.first; // TODO(dsinclair): We're using the binding to make up the buffer number but // we should instead be using a provided mapping that uses both buffer and // set. https://bugs.chromium.org/p/tint/issues/detail?id=104 auto* binding = data.second.binding; if (binding == nullptr) { diagnostics_.add_error( "unable to find binding information for uniform: " + program_->Symbols().NameFor(var->Declaration()->symbol())); return false; } // auto* set = data.second.set; out_ << "constant "; // TODO(dsinclair): Can you have a uniform array? If so, this needs to be // updated to handle arrays property. if (!EmitType(var->Type(), "")) { return false; } out_ << "& " << program_->Symbols().NameFor(var->Declaration()->symbol()) << " [[buffer(" << binding->value() << ")]]"; } for (auto data : func_sem->ReferencedStorageBufferVariables()) { if (!first) { out_ << ", "; } first = false; auto* var = data.first; // TODO(dsinclair): We're using the binding to make up the buffer number but // we should instead be using a provided mapping that uses both buffer and // set. https://bugs.chromium.org/p/tint/issues/detail?id=104 auto* binding = data.second.binding; // auto* set = data.second.set; auto* ac = var->Type()->As(); if (ac == nullptr) { diagnostics_.add_error( "invalid type for storage buffer, expected access control"); return false; } if (ac->IsReadOnly()) { out_ << "const "; } out_ << "device "; if (!EmitType(ac->type(), "")) { return false; } out_ << "& " << program_->Symbols().NameFor(var->Declaration()->symbol()) << " [[buffer(" << binding->value() << ")]]"; } out_ << ") {" << std::endl; increment_indent(); if (has_out_data) { make_indent(); out_ << out_data->second.struct_name << " " << out_data->second.var_name << " = {};" << std::endl; } generating_entry_point_ = true; for (auto* s : *func->body()) { if (!EmitStatement(s)) { return false; } } auto* last_statement = func->get_last_statement(); if (last_statement == nullptr || !last_statement->Is()) { ast::ReturnStatement ret(Source{}); if (!EmitStatement(&ret)) { return false; } } generating_entry_point_ = false; decrement_indent(); make_indent(); out_ << "}" << std::endl; current_ep_sym_ = Symbol(); return true; } bool GeneratorImpl::global_is_in_struct(const semantic::Variable* var) const { bool in_or_out_struct_has_location = var != nullptr && var->Declaration()->HasLocationDecoration() && (var->StorageClass() == ast::StorageClass::kInput || var->StorageClass() == ast::StorageClass::kOutput); bool in_struct_has_builtin = var != nullptr && var->Declaration()->HasBuiltinDecoration() && var->StorageClass() == ast::StorageClass::kOutput; return in_or_out_struct_has_location || in_struct_has_builtin; } bool GeneratorImpl::EmitIdentifier(ast::IdentifierExpression* expr) { auto* ident = expr->As(); const semantic::Variable* var = nullptr; if (global_variables_.get(ident->symbol(), &var)) { if (global_is_in_struct(var)) { auto var_type = var->StorageClass() == ast::StorageClass::kInput ? VarType::kIn : VarType::kOut; auto name = current_ep_var_name(var_type); if (name.empty()) { diagnostics_.add_error("unable to find entry point data for variable"); return false; } out_ << name << "."; } } out_ << program_->Symbols().NameFor(ident->symbol()); return true; } bool GeneratorImpl::EmitLoop(ast::LoopStatement* stmt) { loop_emission_counter_++; std::string guard = "tint_msl_is_first_" + std::to_string(loop_emission_counter_); if (stmt->has_continuing()) { make_indent(); // Continuing variables get their own scope. out_ << "{" << std::endl; increment_indent(); make_indent(); out_ << "bool " << guard << " = true;" << std::endl; // A continuing block may use variables declared in the method body. As a // first pass, if we have a continuing, we pull all declarations outside // the for loop into the continuing scope. Then, the variable declarations // will be turned into assignments. for (auto* s : *(stmt->body())) { if (auto* decl = s->As()) { if (!EmitVariable(program_->Sem().Get(decl->variable()), true)) { return false; } } } } make_indent(); out_ << "for(;;) {" << std::endl; increment_indent(); if (stmt->has_continuing()) { make_indent(); out_ << "if (!" << guard << ") "; if (!EmitBlockAndNewline(stmt->continuing())) { return false; } make_indent(); out_ << guard << " = false;" << std::endl; out_ << std::endl; } for (auto* s : *(stmt->body())) { // If we have a continuing block we've already emitted the variable // declaration before the loop, so treat it as an assignment. auto* decl = s->As(); if (decl != nullptr && stmt->has_continuing()) { make_indent(); auto* var = decl->variable(); out_ << program_->Symbols().NameFor(var->symbol()) << " = "; if (var->constructor() != nullptr) { if (!EmitExpression(var->constructor())) { return false; } } else { if (!EmitZeroValue(program_->Sem().Get(var)->Type())) { return false; } } out_ << ";" << std::endl; continue; } if (!EmitStatement(s)) { return false; } } decrement_indent(); make_indent(); out_ << "}" << std::endl; // Close the scope for any continuing variables. if (stmt->has_continuing()) { decrement_indent(); make_indent(); out_ << "}" << std::endl; } return true; } bool GeneratorImpl::EmitDiscard(ast::DiscardStatement*) { make_indent(); // TODO(dsinclair): Verify this is correct when the discard semantics are // defined for WGSL (https://github.com/gpuweb/gpuweb/issues/361) out_ << "discard_fragment();" << std::endl; return true; } bool GeneratorImpl::EmitElse(ast::ElseStatement* stmt) { if (stmt->HasCondition()) { out_ << " else if ("; if (!EmitExpression(stmt->condition())) { return false; } out_ << ") "; } else { out_ << " else "; } return EmitBlock(stmt->body()); } bool GeneratorImpl::EmitIf(ast::IfStatement* stmt) { make_indent(); out_ << "if ("; if (!EmitExpression(stmt->condition())) { return false; } out_ << ") "; if (!EmitBlock(stmt->body())) { return false; } for (auto* e : stmt->else_statements()) { if (!EmitElse(e)) { return false; } } out_ << std::endl; return true; } bool GeneratorImpl::EmitMemberAccessor(ast::MemberAccessorExpression* expr) { if (!EmitExpression(expr->structure())) { return false; } out_ << "."; // Swizzles get written out directly if (program_->Sem().Get(expr)->IsSwizzle()) { out_ << program_->Symbols().NameFor(expr->member()->symbol()); } else if (!EmitExpression(expr->member())) { return false; } return true; } bool GeneratorImpl::EmitReturn(ast::ReturnStatement* stmt) { make_indent(); out_ << "return"; if (generating_entry_point_) { auto out_data = ep_sym_to_out_data_.find(current_ep_sym_); if (out_data != ep_sym_to_out_data_.end()) { out_ << " " << out_data->second.var_name; } } else if (stmt->has_value()) { out_ << " "; if (!EmitExpression(stmt->value())) { return false; } } out_ << ";" << std::endl; return true; } bool GeneratorImpl::EmitBlock(const ast::BlockStatement* stmt) { out_ << "{" << std::endl; increment_indent(); for (auto* s : *stmt) { if (!EmitStatement(s)) { return false; } } decrement_indent(); make_indent(); out_ << "}"; return true; } bool GeneratorImpl::EmitBlockAndNewline(const ast::BlockStatement* stmt) { const bool result = EmitBlock(stmt); if (result) { out_ << std::endl; } return result; } bool GeneratorImpl::EmitIndentedBlockAndNewline(ast::BlockStatement* stmt) { make_indent(); const bool result = EmitBlock(stmt); if (result) { out_ << std::endl; } return result; } bool GeneratorImpl::EmitStatement(ast::Statement* stmt) { if (auto* a = stmt->As()) { return EmitAssign(a); } if (auto* b = stmt->As()) { return EmitIndentedBlockAndNewline(b); } if (auto* b = stmt->As()) { return EmitBreak(b); } if (auto* c = stmt->As()) { make_indent(); if (!EmitCall(c->expr())) { return false; } out_ << ";" << std::endl; return true; } if (auto* c = stmt->As()) { return EmitContinue(c); } if (auto* d = stmt->As()) { return EmitDiscard(d); } if (stmt->As()) { make_indent(); out_ << "/* fallthrough */" << std::endl; return true; } if (auto* i = stmt->As()) { return EmitIf(i); } if (auto* l = stmt->As()) { return EmitLoop(l); } if (auto* r = stmt->As()) { return EmitReturn(r); } if (auto* s = stmt->As()) { return EmitSwitch(s); } if (auto* v = stmt->As()) { auto* var = program_->Sem().Get(v->variable()); return EmitVariable(var, false); } diagnostics_.add_error("unknown statement type: " + program_->str(stmt)); return false; } bool GeneratorImpl::EmitSwitch(ast::SwitchStatement* stmt) { make_indent(); out_ << "switch("; if (!EmitExpression(stmt->condition())) { return false; } out_ << ") {" << std::endl; increment_indent(); for (auto* s : stmt->body()) { if (!EmitCase(s)) { return false; } } decrement_indent(); make_indent(); out_ << "}" << std::endl; return true; } bool GeneratorImpl::EmitType(type::Type* type, const std::string& name) { std::string access_str = ""; if (auto* ac = type->As()) { if (ac->access_control() == ast::AccessControl::kReadOnly) { access_str = "read"; } else if (ac->access_control() == ast::AccessControl::kWriteOnly) { access_str = "write"; } else { diagnostics_.add_error("Invalid access control for storage texture"); return false; } type = ac->type(); } if (auto* alias = type->As()) { out_ << program_->Symbols().NameFor(alias->symbol()); } else if (auto* ary = type->As()) { type::Type* base_type = ary; std::vector sizes; while (auto* arr = base_type->As()) { if (arr->IsRuntimeArray()) { sizes.push_back(1); } else { sizes.push_back(arr->size()); } base_type = arr->type(); } if (!EmitType(base_type, "")) { return false; } if (!name.empty()) { out_ << " " << name; } for (uint32_t size : sizes) { out_ << "[" << size << "]"; } } else if (type->Is()) { out_ << "bool"; } else if (type->Is()) { out_ << "float"; } else if (type->Is()) { out_ << "int"; } else if (auto* mat = type->As()) { if (!EmitType(mat->type(), "")) { return false; } out_ << mat->columns() << "x" << mat->rows(); } else if (auto* ptr = type->As()) { // TODO(dsinclair): Storage class? if (!EmitType(ptr->type(), "")) { return false; } out_ << "*"; } else if (type->Is()) { out_ << "sampler"; } else if (auto* str = type->As()) { // The struct type emits as just the name. The declaration would be emitted // as part of emitting the constructed types. out_ << program_->Symbols().NameFor(str->symbol()); } else if (auto* tex = type->As()) { if (tex->Is()) { out_ << "depth"; } else { out_ << "texture"; } switch (tex->dim()) { case type::TextureDimension::k1d: out_ << "1d"; break; case type::TextureDimension::k2d: out_ << "2d"; break; case type::TextureDimension::k2dArray: out_ << "2d_array"; break; case type::TextureDimension::k3d: out_ << "3d"; break; case type::TextureDimension::kCube: out_ << "cube"; break; case type::TextureDimension::kCubeArray: out_ << "cube_array"; break; default: diagnostics_.add_error("Invalid texture dimensions"); return false; } if (tex->Is()) { out_ << "_ms"; } out_ << "<"; if (tex->Is()) { out_ << "float, access::sample"; } else if (auto* storage = tex->As()) { if (!EmitType(storage->type(), "")) { return false; } out_ << ", access::" << access_str; } else if (auto* ms = tex->As()) { if (!EmitType(ms->type(), "")) { return false; } out_ << ", access::sample"; } else if (auto* sampled = tex->As()) { if (!EmitType(sampled->type(), "")) { return false; } out_ << ", access::sample"; } else { diagnostics_.add_error("invalid texture type"); return false; } out_ << ">"; } else if (type->Is()) { out_ << "uint"; } else if (auto* vec = type->As()) { if (!EmitType(vec->type(), "")) { return false; } out_ << vec->size(); } else if (type->Is()) { out_ << "void"; } else { diagnostics_.add_error("unknown type in EmitType: " + type->type_name()); return false; } return true; } bool GeneratorImpl::EmitPackedType(type::Type* type, const std::string& name) { if (auto* alias = type->As()) { return EmitPackedType(alias->type(), name); } if (auto* vec = type->As()) { out_ << "packed_"; if (!EmitType(vec->type(), "")) { return false; } out_ << vec->size(); return true; } return EmitType(type, name); } bool GeneratorImpl::EmitStructType(const type::Struct* str) { // TODO(dsinclair): Block decoration? // if (str->impl()->decoration() != ast::Decoration::kNone) { // } out_ << "struct " << program_->Symbols().NameFor(str->symbol()) << " {" << std::endl; auto* sem_str = program_->Sem().Get(str); if (!sem_str) { TINT_ICE(diagnostics_) << "struct missing semantic info"; return false; } bool is_host_shareable = sem_str->IsHostShareable(); // Emits a `/* 0xnnnn */` byte offset comment for a struct member. auto add_byte_offset_comment = [&](uint32_t offset) { std::ios_base::fmtflags saved_flag_state(out_.flags()); out_ << "/* 0x" << std::hex << std::setfill('0') << std::setw(4) << offset << " */ "; out_.flags(saved_flag_state); }; uint32_t pad_count = 0; auto add_padding = [&](uint32_t size) { out_ << "int8_t _tint_pad_" << pad_count << "[" << size << "];" << std::endl; pad_count++; }; increment_indent(); uint32_t msl_offset = 0; for (auto* mem : str->impl()->members()) { make_indent(); auto* sem_mem = program_->Sem().Get(mem); if (!sem_mem) { TINT_ICE(diagnostics_) << "struct member missing semantic info"; return false; } auto wgsl_offset = sem_mem->Offset(); if (is_host_shareable) { if (wgsl_offset < msl_offset) { // Unimplementable layout TINT_ICE(diagnostics_) << "Structure member WGSL offset (" << wgsl_offset << ") is behind MSL offset (" << msl_offset << ")"; return false; } // Generate padding if required if (auto padding = wgsl_offset - msl_offset) { add_byte_offset_comment(msl_offset); add_padding(padding); msl_offset += padding; make_indent(); } add_byte_offset_comment(msl_offset); if (!EmitPackedType(mem->type(), program_->Symbols().NameFor(mem->symbol()))) { return false; } } else { if (!EmitType(mem->type(), program_->Symbols().NameFor(mem->symbol()))) { return false; } } auto* ty = mem->type()->UnwrapAliasIfNeeded(); // Array member name will be output with the type if (!ty->Is()) { out_ << " " << program_->Symbols().NameFor(mem->symbol()); } // Emit decorations for (auto* deco : mem->decorations()) { if (auto* loc = deco->As()) { out_ << " [[user(locn" + std::to_string(loc->value()) + ")]]"; } } out_ << ";" << std::endl; if (is_host_shareable) { // Calculate new MSL offset auto size_align = MslPackedTypeSizeAndAlign(ty); if (msl_offset % size_align.align) { TINT_ICE(diagnostics_) << "Misaligned MSL structure member " << ty->FriendlyName(program_->Symbols()) << " " << program_->Symbols().NameFor(mem->symbol()); return false; } msl_offset += size_align.size; } } if (is_host_shareable && sem_str->Size() != msl_offset) { make_indent(); add_byte_offset_comment(msl_offset); add_padding(sem_str->Size() - msl_offset); } decrement_indent(); make_indent(); out_ << "};" << std::endl; return true; } bool GeneratorImpl::EmitUnaryOp(ast::UnaryOpExpression* expr) { switch (expr->op()) { case ast::UnaryOp::kNot: out_ << "!"; break; case ast::UnaryOp::kNegation: out_ << "-"; break; } out_ << "("; if (!EmitExpression(expr->expr())) { return false; } out_ << ")"; return true; } bool GeneratorImpl::EmitVariable(const semantic::Variable* var, bool skip_constructor) { make_indent(); auto* decl = var->Declaration(); // TODO(dsinclair): Handle variable decorations if (!decl->decorations().empty()) { diagnostics_.add_error("Variable decorations are not handled yet"); return false; } if (decl->is_const()) { out_ << "const "; } if (!EmitType(var->Type(), program_->Symbols().NameFor(decl->symbol()))) { return false; } if (!var->Type()->Is()) { out_ << " " << program_->Symbols().NameFor(decl->symbol()); } if (!skip_constructor) { out_ << " = "; if (decl->constructor() != nullptr) { if (!EmitExpression(decl->constructor())) { return false; } } else if (var->StorageClass() == ast::StorageClass::kPrivate || var->StorageClass() == ast::StorageClass::kFunction || var->StorageClass() == ast::StorageClass::kNone || var->StorageClass() == ast::StorageClass::kOutput) { if (!EmitZeroValue(var->Type())) { return false; } } } out_ << ";" << std::endl; return true; } bool GeneratorImpl::EmitProgramConstVariable(const ast::Variable* var) { make_indent(); for (auto* d : var->decorations()) { if (!d->Is()) { diagnostics_.add_error("Decorated const values not valid"); return false; } } if (!var->is_const()) { diagnostics_.add_error("Expected a const value"); return false; } out_ << "constant "; auto* type = program_->Sem().Get(var)->Type(); if (!EmitType(type, program_->Symbols().NameFor(var->symbol()))) { return false; } if (!type->Is()) { out_ << " " << program_->Symbols().NameFor(var->symbol()); } if (var->HasConstantIdDecoration()) { out_ << " [[function_constant(" << var->constant_id() << ")]]"; } else if (var->constructor() != nullptr) { out_ << " = "; if (!EmitExpression(var->constructor())) { return false; } } out_ << ";" << std::endl; return true; } GeneratorImpl::SizeAndAlign GeneratorImpl::MslPackedTypeSizeAndAlign( type::Type* ty) { ty = ty->UnwrapAliasIfNeeded(); if (ty->IsAnyOf()) { // https://developer.apple.com/metal/Metal-Shading-Language-Specification.pdf // 2.1 Scalar Data Types return {4, 4}; } if (auto* vec = ty->As()) { // https://developer.apple.com/metal/Metal-Shading-Language-Specification.pdf // 2.2.3 Packed Vector Types auto num_els = vec->size(); auto* el_ty = vec->type()->UnwrapAll(); if (el_ty->IsAnyOf()) { return SizeAndAlign{num_els * 4, 4}; } } if (auto* mat = ty->As()) { // https://developer.apple.com/metal/Metal-Shading-Language-Specification.pdf // 2.3 Matrix Data Types auto cols = mat->columns(); auto rows = mat->rows(); auto* el_ty = mat->type()->UnwrapAll(); if (el_ty->IsAnyOf()) { static constexpr SizeAndAlign table[] = { /* float2x2 */ {16, 8}, /* float2x3 */ {32, 16}, /* float2x4 */ {32, 16}, /* float3x2 */ {24, 8}, /* float3x3 */ {48, 16}, /* float3x4 */ {48, 16}, /* float4x2 */ {32, 8}, /* float4x3 */ {64, 16}, /* float4x4 */ {64, 16}, }; if (cols >= 2 && cols <= 4 && rows >= 2 && rows <= 4) { return table[(3 * (cols - 2)) + (rows - 2)]; } } } if (auto* arr = ty->As()) { auto* sem = program_->Sem().Get(arr); if (!sem) { TINT_ICE(diagnostics_) << "Array missing semantic info"; return {}; } auto el_size_align = MslPackedTypeSizeAndAlign(arr->type()); if (sem->Stride() != el_size_align.size) { // TODO(crbug.com/tint/649): transform::Msl needs to replace these arrays // with a new array type that has the element type padded to the required // stride. TINT_UNIMPLEMENTED(diagnostics_) << "Arrays with custom strides not yet implemented"; return {}; } auto num_els = std::max(arr->size(), 1); return SizeAndAlign{el_size_align.size * num_els, el_size_align.align}; } if (auto* str = ty->As()) { // TODO(crbug.com/tint/650): There's an assumption here that MSL's default // structure size and alignment matches WGSL's. We need to confirm this. auto* sem = program_->Sem().Get(str); if (!sem) { TINT_ICE(diagnostics_) << "Array missing semantic info"; return {}; } return SizeAndAlign{sem->Size(), sem->Align()}; } TINT_UNREACHABLE(diagnostics_) << "Unhandled type " << ty->TypeInfo().name; return {}; } } // namespace msl } // namespace writer } // namespace tint