resolver: Validate pipline stage use for intrinsics

Use the new [[stage()]] decorations in intrinsics.def to validate that intrinsics are only called from the correct pipeline stages.

Fixed: tint:657
Change-Id: I9efda26369c45c6f816bdaa53408d3909db403a1
Reviewed-on: https://dawn-review.googlesource.com/c/tint/+/53084
Reviewed-by: Ryan Harrison <rharrison@chromium.org>
Reviewed-by: David Neto <dneto@google.com>
Commit-Queue: Ben Clayton <bclayton@google.com>
Kokoro: Kokoro <noreply+kokoro@google.com>
This commit is contained in:
Ben Clayton 2021-06-03 16:07:34 +00:00 committed by Tint LUCI CQ
parent 7b366475ed
commit 71786c99b3
21 changed files with 611 additions and 148 deletions

View File

@ -511,6 +511,7 @@ libtint_source_set("libtint_core_all_src") {
"sem/node.h",
"sem/parameter_usage.cc",
"sem/parameter_usage.h",
"sem/pipeline_stage_set.h",
"sem/pointer_type.cc",
"sem/pointer_type.h",
"sem/reference_type.cc",

View File

@ -254,6 +254,7 @@ set(TINT_LIB_SRCS
sem/member_accessor_expression.cc
sem/parameter_usage.cc
sem/parameter_usage.h
sem/pipeline_stage_set.h
sem/node.cc
sem/node.h
sem/statement.cc
@ -580,6 +581,7 @@ if(${TINT_BUILD_TESTS})
resolver/function_validation_test.cc
resolver/host_shareable_validation_test.cc
resolver/intrinsic_test.cc
resolver/intrinsic_validation_test.cc
resolver/is_host_shareable_test.cc
resolver/is_storeable_test.cc
resolver/ptr_ref_test.cc

View File

@ -23,6 +23,7 @@
#include "src/sem/depth_texture_type.h"
#include "src/sem/external_texture_type.h"
#include "src/sem/multisampled_texture_type.h"
#include "src/sem/pipeline_stage_set.h"
#include "src/sem/sampled_texture_type.h"
#include "src/sem/storage_texture_type.h"
#include "src/utils/scoped_assignment.h"
@ -288,6 +289,8 @@ using TexelFormat = ast::ImageFormat;
using AccessControl = ast::AccessControl::Access;
using StorageClass = ast::StorageClass;
using ParameterUsage = sem::ParameterUsage;
using PipelineStageSet = sem::PipelineStageSet;
using PipelineStage = ast::PipelineStage;
bool match_bool(const sem::Type* ty) {
return ty->IsAnyOf<Any, sem::Bool>();
@ -608,6 +611,66 @@ const sem::ExternalTexture* build_texture_external(MatchState& state) {
return state.builder.create<sem::ExternalTexture>();
}
/// ParameterInfo describes a parameter
struct ParameterInfo {
/// The parameter usage (parameter name in definition file)
ParameterUsage const usage;
/// Pointer to a list of indices that are used to match the parameter type.
/// The matcher indices index on Matchers::type and / or Matchers::number.
/// These indices are consumed by the matchers themselves.
/// The first index is always a TypeMatcher.
MatcherIndex const* const matcher_indices;
};
/// OpenTypeInfo describes an open type
struct OpenTypeInfo {
/// Name of the open type (e.g. 'T')
const char* name;
/// Optional type matcher constraint.
/// Either an index in Matchers::type, or kNoMatcher
MatcherIndex const matcher_index;
};
/// OpenNumberInfo describes an open number
struct OpenNumberInfo {
/// Name of the open number (e.g. 'N')
const char* name;
/// Optional number matcher constraint.
/// Either an index in Matchers::number, or kNoMatcher
MatcherIndex const matcher_index;
};
/// OverloadInfo describes a single function overload
struct OverloadInfo {
/// Total number of parameters for the overload
uint8_t const num_parameters;
/// Total number of open types for the overload
uint8_t const num_open_types;
/// Total number of open numbers for the overload
uint8_t const num_open_numbers;
/// Pointer to the first open type
OpenTypeInfo const* const open_types;
/// Pointer to the first open number
OpenNumberInfo const* const open_numbers;
/// Pointer to the first parameter
ParameterInfo const* const parameters;
/// Pointer to a list of matcher indices that index on Matchers::type and
/// Matchers::number, used to build the return type. If the function has no
/// return type then this is null.
MatcherIndex const* const return_matcher_indices;
/// The pipeline stages that this overload can be used in.
PipelineStageSet supported_stages;
};
/// IntrinsicInfo describes an intrinsic function
struct IntrinsicInfo {
/// Number of overloads of the intrinsic function
uint8_t const num_overloads;
/// Pointer to the start of the overloads for the function
OverloadInfo const* const overloads;
};
#include "intrinsic_table.inl"
/// Impl is the private implementation of the IntrinsicTable interface.
@ -807,9 +870,9 @@ const sem::Intrinsic* Impl::Match(sem::IntrinsicType intrinsic_type,
return_type = builder.create<sem::Void>();
}
return builder.create<sem::Intrinsic>(intrinsic_type,
const_cast<sem::Type*>(return_type),
std::move(parameters));
return builder.create<sem::Intrinsic>(
intrinsic_type, const_cast<sem::Type*>(return_type),
std::move(parameters), overload.supported_stages);
}
MatchState Impl::Match(ClosedState& closed,

File diff suppressed because it is too large Load Diff

View File

@ -11,64 +11,6 @@ See:
// clang-format off
/// ParameterInfo describes a parameter
struct ParameterInfo {
/// The parameter usage (parameter name in definition file)
ParameterUsage const usage;
/// Pointer to a list of indices that are used to match the parameter type.
/// The matcher indices index on Matchers::type and / or Matchers::number.
/// These indices are consumed by the matchers themselves.
/// The first index is always a TypeMatcher.
MatcherIndex const* const matcher_indices;
};
/// OpenTypeInfo describes an open type
struct OpenTypeInfo {
/// Name of the open type (e.g. 'T')
const char* name;
/// Optional type matcher constraint.
/// Either an index in Matchers::type, or kNoMatcher
MatcherIndex const matcher_index;
};
/// OpenNumberInfo describes an open number
struct OpenNumberInfo {
/// Name of the open number (e.g. 'N')
const char* name;
/// Optional number matcher constraint.
/// Either an index in Matchers::number, or kNoMatcher
MatcherIndex const matcher_index;
};
/// OverloadInfo describes a single function overload
struct OverloadInfo {
/// Total number of parameters for the overload
uint8_t const num_parameters;
/// Total number of open types for the overload
uint8_t const num_open_types;
/// Total number of open numbers for the overload
uint8_t const num_open_numbers;
/// Pointer to the first open type
OpenTypeInfo const* const open_types;
/// Pointer to the first open number
OpenNumberInfo const* const open_numbers;
/// Pointer to the first parameter
ParameterInfo const* const parameters;
/// Pointer to a list of matcher indices that index on Matchers::type and
/// Matchers::number, used to build the return type. If the function has no
/// return type then this is null.
MatcherIndex const* const return_matcher_indices;
};
/// IntrinsicInfo describes an intrinsic function
struct IntrinsicInfo {
/// Number of overloads of the intrinsic function
uint8_t const num_overloads;
/// Pointer to the start of the overloads for the function
OverloadInfo const* const overloads;
};
{{ with .Sem -}}
{{ range .Types -}}
{{ template "Type" . }}
@ -155,6 +97,10 @@ constexpr OverloadInfo kOverloads[] = {
{{- if $o.ReturnMatcherIndicesOffset }} &kMatcherIndices[{{$o.ReturnMatcherIndicesOffset}}]
{{- else }} nullptr
{{- end }},
/* supported_stages */ PipelineStageSet(
{{- range $i, $u := $o.CanBeUsedInStage.List -}}
{{- if $i -}}, {{end}}PipelineStage::k{{Title $u}}
{{- end }}),
},
{{- end }}
};

View File

@ -12,6 +12,7 @@
// See the License for the specific language governing permissions and
// limitations under the License.
#include "src/ast/call_statement.h"
#include "src/resolver/resolver_test_helper.h"
namespace tint {
@ -283,7 +284,8 @@ TEST_P(FloatAllMatching, Scalar) {
params.push_back(Expr(1.0f));
}
auto* builtin = Call(name, params);
WrapInFunction(builtin);
Func("func", {}, ty.void_(), {create<ast::CallStatement>(builtin)},
{create<ast::StageDecoration>(ast::PipelineStage::kFragment)});
EXPECT_TRUE(r()->Resolve()) << r()->error();
EXPECT_TRUE(TypeOf(builtin)->Is<sem::F32>());
@ -298,7 +300,8 @@ TEST_P(FloatAllMatching, Vec2) {
params.push_back(vec2<f32>(1.0f, 1.0f));
}
auto* builtin = Call(name, params);
WrapInFunction(builtin);
Func("func", {}, ty.void_(), {create<ast::CallStatement>(builtin)},
{create<ast::StageDecoration>(ast::PipelineStage::kFragment)});
EXPECT_TRUE(r()->Resolve()) << r()->error();
EXPECT_TRUE(TypeOf(builtin)->is_float_vector());
@ -313,7 +316,8 @@ TEST_P(FloatAllMatching, Vec3) {
params.push_back(vec3<f32>(1.0f, 1.0f, 1.0f));
}
auto* builtin = Call(name, params);
WrapInFunction(builtin);
Func("func", {}, ty.void_(), {create<ast::CallStatement>(builtin)},
{create<ast::StageDecoration>(ast::PipelineStage::kFragment)});
EXPECT_TRUE(r()->Resolve()) << r()->error();
EXPECT_TRUE(TypeOf(builtin)->is_float_vector());
@ -328,7 +332,8 @@ TEST_P(FloatAllMatching, Vec4) {
params.push_back(vec4<f32>(1.0f, 1.0f, 1.0f, 1.0f));
}
auto* builtin = Call(name, params);
WrapInFunction(builtin);
Func("func", {}, ty.void_(), {create<ast::CallStatement>(builtin)},
{create<ast::StageDecoration>(ast::PipelineStage::kFragment)});
EXPECT_TRUE(r()->Resolve()) << r()->error();
EXPECT_TRUE(TypeOf(builtin)->is_float_vector());

View File

@ -55,7 +55,8 @@ TEST_P(ResolverIntrinsicDerivativeTest, Scalar) {
Global("ident", ty.f32(), ast::StorageClass::kInput);
auto* expr = Call(name, "ident");
WrapInFunction(expr);
Func("func", {}, ty.void_(), {create<ast::CallStatement>(expr)},
{create<ast::StageDecoration>(ast::PipelineStage::kFragment)});
EXPECT_TRUE(r()->Resolve()) << r()->error();
@ -68,7 +69,8 @@ TEST_P(ResolverIntrinsicDerivativeTest, Vector) {
Global("ident", ty.vec4<f32>(), ast::StorageClass::kInput);
auto* expr = Call(name, "ident");
WrapInFunction(expr);
Func("func", {}, ty.void_(), {create<ast::CallStatement>(expr)},
{create<ast::StageDecoration>(ast::PipelineStage::kFragment)});
EXPECT_TRUE(r()->Resolve()) << r()->error();
@ -1927,7 +1929,8 @@ TEST_P(ResolverIntrinsicTest_Texture, Call) {
param.buildSamplerVariable(this);
auto* call = Call(param.function, param.args(this));
WrapInFunction(call);
Func("func", {}, ty.void_(), {create<ast::CallStatement>(call)},
{create<ast::StageDecoration>(ast::PipelineStage::kFragment)});
ASSERT_TRUE(r()->Resolve()) << r()->error();

View File

@ -0,0 +1,97 @@
// Copyright 2021 The Tint Authors.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include "src/resolver/resolver.h"
#include "gmock/gmock.h"
#include "src/ast/assignment_statement.h"
#include "src/ast/bitcast_expression.h"
#include "src/ast/break_statement.h"
#include "src/ast/call_statement.h"
#include "src/ast/continue_statement.h"
#include "src/ast/if_statement.h"
#include "src/ast/intrinsic_texture_helper_test.h"
#include "src/ast/loop_statement.h"
#include "src/ast/return_statement.h"
#include "src/ast/stage_decoration.h"
#include "src/ast/struct_block_decoration.h"
#include "src/ast/switch_statement.h"
#include "src/ast/unary_op_expression.h"
#include "src/ast/variable_decl_statement.h"
#include "src/resolver/resolver_test_helper.h"
#include "src/sem/call.h"
#include "src/sem/function.h"
#include "src/sem/member_accessor_expression.h"
#include "src/sem/sampled_texture_type.h"
#include "src/sem/statement.h"
#include "src/sem/variable.h"
using ::testing::ElementsAre;
using ::testing::HasSubstr;
namespace tint {
namespace resolver {
namespace {
using IntrinsicType = sem::IntrinsicType;
using ResolverIntrinsicValidationTest = ResolverTest;
TEST_F(ResolverIntrinsicValidationTest, InvalidPipelineStageDirect) {
// [[stage(compute)]] fn func { return dpdx(1.0); }
auto* dpdx = create<ast::CallExpression>(Source{{3, 4}}, Expr("dpdx"),
ast::ExpressionList{Expr(1.0f)});
Func(Source{{1, 2}}, "func", ast::VariableList{}, ty.void_(),
{create<ast::CallStatement>(dpdx)},
{Stage(ast::PipelineStage::kCompute)});
EXPECT_FALSE(r()->Resolve());
EXPECT_EQ(r()->error(),
"3:4 error: built-in cannot be used by compute pipeline stage");
}
TEST_F(ResolverIntrinsicValidationTest, InvalidPipelineStageIndirect) {
// fn f0 { return dpdx(1.0); }
// fn f1 { f0(); }
// fn f2 { f1(); }
// [[stage(compute)]] fn main { return f2(); }
auto* dpdx = create<ast::CallExpression>(Source{{3, 4}}, Expr("dpdx"),
ast::ExpressionList{Expr(1.0f)});
Func(Source{{1, 2}}, "f0", ast::VariableList{}, ty.void_(),
{create<ast::CallStatement>(dpdx)});
Func(Source{{3, 4}}, "f1", ast::VariableList{}, ty.void_(),
{create<ast::CallStatement>(Call("f0"))});
Func(Source{{5, 6}}, "f2", ast::VariableList{}, ty.void_(),
{create<ast::CallStatement>(Call("f1"))});
Func(Source{{7, 8}}, "main", ast::VariableList{}, ty.void_(),
{create<ast::CallStatement>(Call("f2"))},
{Stage(ast::PipelineStage::kCompute)});
EXPECT_FALSE(r()->Resolve());
EXPECT_EQ(r()->error(),
R"(3:4 error: built-in cannot be used by compute pipeline stage
1:2 note: called by function 'f0'
3:4 note: called by function 'f1'
5:6 note: called by function 'f2'
7:8 note: called by entry point 'main')");
}
} // namespace
} // namespace resolver
} // namespace tint

View File

@ -237,6 +237,10 @@ bool Resolver::ResolveInternal() {
}
}
if (!ValidatePipelineStages()) {
return false;
}
bool result = true;
for (auto* node : builder_->ASTNodes().Objects()) {
@ -1129,6 +1133,10 @@ bool Resolver::ValidateEntryPoint(const ast::Function* func,
bool Resolver::Function(ast::Function* func) {
auto* info = function_infos_.Create<FunctionInfo>(func);
if (func->IsEntryPoint()) {
entry_points_.emplace_back(info);
}
TINT_SCOPED_ASSIGNMENT(current_function_, info);
variable_stack_.push_scope();
@ -1707,6 +1715,10 @@ bool Resolver::IntrinsicCall(ast::CallExpression* call,
builder_->Sem().Add(
call, builder_->create<sem::Call>(call, result, current_statement_));
SetType(call, result->ReturnType());
current_function_->intrinsic_calls.emplace_back(
IntrinsicCallInfo{call, result});
return true;
}
@ -2460,25 +2472,76 @@ void Resolver::SetType(ast::Expression* expr,
expr_info_.emplace(expr, ExpressionInfo{type, type_name, current_statement_});
}
bool Resolver::ValidatePipelineStages() {
auto check_intrinsic_calls = [&](FunctionInfo* func,
FunctionInfo* entry_point) {
auto stage = entry_point->declaration->pipeline_stage();
for (auto& call : func->intrinsic_calls) {
if (!call.intrinsic->SupportedStages().Contains(stage)) {
std::stringstream err;
err << "built-in cannot be used by " << stage << " pipeline stage";
diagnostics_.add_error(err.str(), call.call->source());
if (func != entry_point) {
TraverseCallChain(entry_point, func, [&](FunctionInfo* f) {
diagnostics_.add_note(
"called by function '" +
builder_->Symbols().NameFor(f->declaration->symbol()) + "'",
f->declaration->source());
});
diagnostics_.add_note("called by entry point '" +
builder_->Symbols().NameFor(
entry_point->declaration->symbol()) +
"'",
entry_point->declaration->source());
}
return false;
}
}
return true;
};
for (auto* entry_point : entry_points_) {
if (!check_intrinsic_calls(entry_point, entry_point)) {
return false;
}
for (auto* func : entry_point->transitive_calls) {
if (!check_intrinsic_calls(func, entry_point)) {
return false;
}
}
}
return true;
}
template <typename CALLBACK>
void Resolver::TraverseCallChain(FunctionInfo* from,
FunctionInfo* to,
CALLBACK&& callback) const {
for (auto* f : from->transitive_calls) {
if (f == to) {
callback(f);
return;
}
if (f->transitive_calls.contains(to)) {
TraverseCallChain(f, to, callback);
callback(f);
return;
}
}
TINT_ICE(diagnostics_)
<< "TraverseCallChain() 'from' does not transitively call 'to'";
}
void Resolver::CreateSemanticNodes() const {
auto& sem = builder_->Sem();
// Collate all the 'ancestor_entry_points' - this is a map of function
// symbol to all the entry points that transitively call the function.
std::unordered_map<Symbol, std::vector<Symbol>> ancestor_entry_points;
for (auto* func : builder_->AST().Functions()) {
auto it = function_to_info_.find(func);
if (it == function_to_info_.end()) {
continue; // Resolver has likely errored. Process what we can.
}
auto* info = it->second;
if (!func->IsEntryPoint()) {
continue;
}
for (auto* call : info->transitive_calls) {
for (auto* entry_point : entry_points_) {
for (auto* call : entry_point->transitive_calls) {
auto& vec = ancestor_entry_points[call->declaration->symbol()];
vec.emplace_back(func->symbol());
vec.emplace_back(entry_point->declaration->symbol());
}
}

View File

@ -51,6 +51,7 @@ class Variable;
} // namespace ast
namespace sem {
class Array;
class Intrinsic;
class Statement;
} // namespace sem
@ -100,6 +101,11 @@ class Resolver {
sem::BindingPoint binding_point;
};
struct IntrinsicCallInfo {
const ast::CallExpression* call;
const sem::Intrinsic* intrinsic;
};
/// Structure holding semantic information about a function.
/// Used to build the sem::Function nodes at the end of resolving.
struct FunctionInfo {
@ -115,9 +121,13 @@ class Resolver {
sem::Type* return_type = nullptr;
std::string return_type_name;
std::array<sem::WorkgroupDimension, 3> workgroup_size;
std::vector<IntrinsicCallInfo> intrinsic_calls;
// List of transitive calls this function makes
UniqueVector<FunctionInfo*> transitive_calls;
// List of entry point functions that transitively call this function
UniqueVector<FunctionInfo*> ancestor_entry_points;
};
/// Structure holding semantic information about an expression.
@ -183,6 +193,8 @@ class Resolver {
/// @returns true on success, false on error
bool ResolveInternal();
bool ValidatePipelineStages();
/// Creates the nodes and adds them to the sem::Info mappings of the
/// ProgramBuilder.
void CreateSemanticNodes() const;
@ -359,12 +371,18 @@ class Resolver {
/// @param node the AST node.
void Mark(const ast::Node* node);
template <typename CALLBACK>
void TraverseCallChain(FunctionInfo* from,
FunctionInfo* to,
CALLBACK&& callback) const;
ProgramBuilder* const builder_;
diag::List& diagnostics_;
std::unique_ptr<IntrinsicTable> const intrinsic_table_;
sem::BlockStatement* current_block_ = nullptr;
ScopeStack<VariableInfo*> variable_stack_;
std::unordered_map<Symbol, FunctionInfo*> symbol_to_function_;
std::vector<FunctionInfo*> entry_points_;
std::unordered_map<const ast::Function*, FunctionInfo*> function_to_info_;
std::unordered_map<const ast::Variable*, VariableInfo*> variable_to_info_;
std::unordered_map<ast::CallExpression*, FunctionCallInfo> function_calls_;

View File

@ -89,8 +89,11 @@ bool IsBarrierIntrinsic(IntrinsicType i) {
Intrinsic::Intrinsic(IntrinsicType type,
sem::Type* return_type,
const ParameterList& parameters)
: Base(return_type, parameters), type_(type) {}
const ParameterList& parameters,
PipelineStageSet supported_stages)
: Base(return_type, parameters),
type_(type),
supported_stages_(supported_stages) {}
Intrinsic::~Intrinsic() = default;

View File

@ -19,6 +19,7 @@
#include "src/sem/call_target.h"
#include "src/sem/intrinsic_type.h"
#include "src/sem/pipeline_stage_set.h"
namespace tint {
namespace sem {
@ -75,9 +76,12 @@ class Intrinsic : public Castable<Intrinsic, CallTarget> {
/// @param type the intrinsic type
/// @param return_type the return type for the intrinsic call
/// @param parameters the parameters for the intrinsic overload
/// @param supported_stages the pipeline stages that this intrinsic can be
/// used in
Intrinsic(IntrinsicType type,
sem::Type* return_type,
const ParameterList& parameters);
const ParameterList& parameters,
PipelineStageSet supported_stages);
/// Destructor
~Intrinsic() override;
@ -85,6 +89,9 @@ class Intrinsic : public Castable<Intrinsic, CallTarget> {
/// @return the type of the intrinsic
IntrinsicType Type() const { return type_; }
/// @return the pipeline stages that this intrinsic can be used in
PipelineStageSet SupportedStages() const { return supported_stages_; }
/// @returns the name of the intrinsic function type. The spelling, including
/// case, matches the name in the WGSL spec.
const char* str() const;
@ -118,6 +125,7 @@ class Intrinsic : public Castable<Intrinsic, CallTarget> {
private:
IntrinsicType const type_;
PipelineStageSet const supported_stages_;
};
/// Emits the name of the intrinsic function type. The spelling, including case,

View File

@ -0,0 +1,29 @@
// Copyright 2021 The Tint Authors.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#ifndef SRC_SEM_PIPELINE_STAGE_SET_H_
#define SRC_SEM_PIPELINE_STAGE_SET_H_
#include "src/ast/pipeline_stage.h"
#include "src/utils/enum_set.h"
namespace tint {
namespace sem {
using PipelineStageSet = utils::EnumSet<ast::PipelineStage>;
} // namespace sem
} // namespace tint
#endif // SRC_SEM_PIPELINE_STAGE_SET_H_

View File

@ -37,6 +37,10 @@ struct UniqueVector {
}
}
/// @returns true if the vector contains `item`
/// @param item the item
bool contains(const T& item) const { return set.count(item); }
/// @returns the number of items in the vector
size_t size() const { return vector.size(); }

View File

@ -163,7 +163,8 @@ TEST_P(HlslIntrinsicTest, Emit) {
auto* call = GenerateCall(param.intrinsic, param.type, this);
ASSERT_NE(nullptr, call) << "Unhandled intrinsic";
WrapInFunction(call);
Func("func", {}, ty.void_(), {create<ast::CallStatement>(call)},
{create<ast::StageDecoration>(ast::PipelineStage::kFragment)});
GeneratorImpl& gen = Build();

View File

@ -12,6 +12,7 @@
// See the License for the specific language governing permissions and
// limitations under the License.
#include "src/ast/call_statement.h"
#include "src/sem/call.h"
#include "src/writer/msl/test_helper.h"
@ -176,7 +177,8 @@ TEST_P(MslIntrinsicTest, Emit) {
auto* call = GenerateCall(param.intrinsic, param.type, this);
ASSERT_NE(nullptr, call) << "Unhandled intrinsic";
WrapInFunction(call);
Func("func", {}, ty.void_(), {create<ast::CallStatement>(call)},
{create<ast::StageDecoration>(ast::PipelineStage::kFragment)});
GeneratorImpl& gen = Build();

View File

@ -414,7 +414,8 @@ TEST_P(IntrinsicDeriveTest, Call_Derivative_Scalar) {
auto* var = Global("v", ty.f32(), ast::StorageClass::kPrivate);
auto* expr = Call(param.name, "v");
WrapInFunction(expr);
Func("func", {}, ty.void_(), {create<ast::CallStatement>(expr)},
{create<ast::StageDecoration>(ast::PipelineStage::kFragment)});
spirv::Builder& b = Build();
@ -439,7 +440,8 @@ TEST_P(IntrinsicDeriveTest, Call_Derivative_Vector) {
auto* var = Global("v", ty.vec3<f32>(), ast::StorageClass::kPrivate);
auto* expr = Call(param.name, "v");
WrapInFunction(expr);
Func("func", {}, ty.void_(), {create<ast::CallStatement>(expr)},
{create<ast::StageDecoration>(ast::PipelineStage::kFragment)});
spirv::Builder& b = Build();

View File

@ -3462,7 +3462,9 @@ TEST_P(IntrinsicTextureTest, Call) {
auto* call =
create<ast::CallExpression>(Expr(param.function), param.args(this));
WrapInFunction(call);
Func("func", {}, ty.void_(), {create<ast::CallStatement>(call)},
{create<ast::StageDecoration>(ast::PipelineStage::kFragment)});
spirv::Builder& b = Build();
@ -3515,7 +3517,8 @@ TEST_P(IntrinsicTextureTest, OutsideFunction_IsError) {
auto* call =
create<ast::CallExpression>(Expr(param.function), param.args(this));
WrapInFunction(call);
Func("func", {}, ty.void_(), {create<ast::CallStatement>(call)},
{create<ast::StageDecoration>(ast::PipelineStage::kFragment)});
spirv::Builder& b = Build();

View File

@ -232,6 +232,7 @@ tint_unittests_source_set("tint_unittests_core_src") {
"../src/resolver/function_validation_test.cc",
"../src/resolver/host_shareable_validation_test.cc",
"../src/resolver/intrinsic_test.cc",
"../src/resolver/intrinsic_validation_test.cc",
"../src/resolver/is_host_shareable_test.cc",
"../src/resolver/is_storeable_test.cc",
"../src/resolver/pipeline_overridable_constant_test.cc",

View File

@ -96,6 +96,8 @@ type Overload struct {
// These indices are consumed by the matchers themselves.
// The first index is always a TypeMatcher.
ReturnMatcherIndicesOffset *int
// StageUses describes the stages an overload can be used in
CanBeUsedInStage sem.StageUses
}
// Function is used to create the C++ IntrinsicInfo structure
@ -193,6 +195,7 @@ func (b *intrinsicTableBuilder) buildOverload(o *sem.Overload) (Overload, error)
OpenNumbersOffset: b.lut.openNumbers.Add(ob.openNumbers),
ParametersOffset: b.lut.parameters.Add(ob.parameters),
ReturnMatcherIndicesOffset: ob.returnTypeMatcherIndicesOffset,
CanBeUsedInStage: o.CanBeUsedInStage,
}, nil
}

View File

@ -146,6 +146,21 @@ type StageUses struct {
Compute bool
}
// List returns the stage uses as a string list
func (u StageUses) List() []string {
out := []string{}
if u.Vertex {
out = append(out, "vertex")
}
if u.Fragment {
out = append(out, "fragment")
}
if u.Compute {
out = append(out, "compute")
}
return out
}
// Format implements the fmt.Formatter interface
func (o Overload) Format(w fmt.State, verb rune) {
fmt.Fprintf(w, "fn %v", o.Function.Name)