resolver: Validate pipline stage use for intrinsics

Use the new [[stage()]] decorations in intrinsics.def to validate that intrinsics are only called from the correct pipeline stages.

Fixed: tint:657
Change-Id: I9efda26369c45c6f816bdaa53408d3909db403a1
Reviewed-on: https://dawn-review.googlesource.com/c/tint/+/53084
Reviewed-by: Ryan Harrison <rharrison@chromium.org>
Reviewed-by: David Neto <dneto@google.com>
Commit-Queue: Ben Clayton <bclayton@google.com>
Kokoro: Kokoro <noreply+kokoro@google.com>
This commit is contained in:
Ben Clayton 2021-06-03 16:07:34 +00:00 committed by Tint LUCI CQ
parent 7b366475ed
commit 71786c99b3
21 changed files with 611 additions and 148 deletions

View File

@ -511,6 +511,7 @@ libtint_source_set("libtint_core_all_src") {
"sem/node.h", "sem/node.h",
"sem/parameter_usage.cc", "sem/parameter_usage.cc",
"sem/parameter_usage.h", "sem/parameter_usage.h",
"sem/pipeline_stage_set.h",
"sem/pointer_type.cc", "sem/pointer_type.cc",
"sem/pointer_type.h", "sem/pointer_type.h",
"sem/reference_type.cc", "sem/reference_type.cc",

View File

@ -254,6 +254,7 @@ set(TINT_LIB_SRCS
sem/member_accessor_expression.cc sem/member_accessor_expression.cc
sem/parameter_usage.cc sem/parameter_usage.cc
sem/parameter_usage.h sem/parameter_usage.h
sem/pipeline_stage_set.h
sem/node.cc sem/node.cc
sem/node.h sem/node.h
sem/statement.cc sem/statement.cc
@ -580,6 +581,7 @@ if(${TINT_BUILD_TESTS})
resolver/function_validation_test.cc resolver/function_validation_test.cc
resolver/host_shareable_validation_test.cc resolver/host_shareable_validation_test.cc
resolver/intrinsic_test.cc resolver/intrinsic_test.cc
resolver/intrinsic_validation_test.cc
resolver/is_host_shareable_test.cc resolver/is_host_shareable_test.cc
resolver/is_storeable_test.cc resolver/is_storeable_test.cc
resolver/ptr_ref_test.cc resolver/ptr_ref_test.cc

View File

@ -23,6 +23,7 @@
#include "src/sem/depth_texture_type.h" #include "src/sem/depth_texture_type.h"
#include "src/sem/external_texture_type.h" #include "src/sem/external_texture_type.h"
#include "src/sem/multisampled_texture_type.h" #include "src/sem/multisampled_texture_type.h"
#include "src/sem/pipeline_stage_set.h"
#include "src/sem/sampled_texture_type.h" #include "src/sem/sampled_texture_type.h"
#include "src/sem/storage_texture_type.h" #include "src/sem/storage_texture_type.h"
#include "src/utils/scoped_assignment.h" #include "src/utils/scoped_assignment.h"
@ -288,6 +289,8 @@ using TexelFormat = ast::ImageFormat;
using AccessControl = ast::AccessControl::Access; using AccessControl = ast::AccessControl::Access;
using StorageClass = ast::StorageClass; using StorageClass = ast::StorageClass;
using ParameterUsage = sem::ParameterUsage; using ParameterUsage = sem::ParameterUsage;
using PipelineStageSet = sem::PipelineStageSet;
using PipelineStage = ast::PipelineStage;
bool match_bool(const sem::Type* ty) { bool match_bool(const sem::Type* ty) {
return ty->IsAnyOf<Any, sem::Bool>(); return ty->IsAnyOf<Any, sem::Bool>();
@ -608,6 +611,66 @@ const sem::ExternalTexture* build_texture_external(MatchState& state) {
return state.builder.create<sem::ExternalTexture>(); return state.builder.create<sem::ExternalTexture>();
} }
/// ParameterInfo describes a parameter
struct ParameterInfo {
/// The parameter usage (parameter name in definition file)
ParameterUsage const usage;
/// Pointer to a list of indices that are used to match the parameter type.
/// The matcher indices index on Matchers::type and / or Matchers::number.
/// These indices are consumed by the matchers themselves.
/// The first index is always a TypeMatcher.
MatcherIndex const* const matcher_indices;
};
/// OpenTypeInfo describes an open type
struct OpenTypeInfo {
/// Name of the open type (e.g. 'T')
const char* name;
/// Optional type matcher constraint.
/// Either an index in Matchers::type, or kNoMatcher
MatcherIndex const matcher_index;
};
/// OpenNumberInfo describes an open number
struct OpenNumberInfo {
/// Name of the open number (e.g. 'N')
const char* name;
/// Optional number matcher constraint.
/// Either an index in Matchers::number, or kNoMatcher
MatcherIndex const matcher_index;
};
/// OverloadInfo describes a single function overload
struct OverloadInfo {
/// Total number of parameters for the overload
uint8_t const num_parameters;
/// Total number of open types for the overload
uint8_t const num_open_types;
/// Total number of open numbers for the overload
uint8_t const num_open_numbers;
/// Pointer to the first open type
OpenTypeInfo const* const open_types;
/// Pointer to the first open number
OpenNumberInfo const* const open_numbers;
/// Pointer to the first parameter
ParameterInfo const* const parameters;
/// Pointer to a list of matcher indices that index on Matchers::type and
/// Matchers::number, used to build the return type. If the function has no
/// return type then this is null.
MatcherIndex const* const return_matcher_indices;
/// The pipeline stages that this overload can be used in.
PipelineStageSet supported_stages;
};
/// IntrinsicInfo describes an intrinsic function
struct IntrinsicInfo {
/// Number of overloads of the intrinsic function
uint8_t const num_overloads;
/// Pointer to the start of the overloads for the function
OverloadInfo const* const overloads;
};
#include "intrinsic_table.inl" #include "intrinsic_table.inl"
/// Impl is the private implementation of the IntrinsicTable interface. /// Impl is the private implementation of the IntrinsicTable interface.
@ -807,9 +870,9 @@ const sem::Intrinsic* Impl::Match(sem::IntrinsicType intrinsic_type,
return_type = builder.create<sem::Void>(); return_type = builder.create<sem::Void>();
} }
return builder.create<sem::Intrinsic>(intrinsic_type, return builder.create<sem::Intrinsic>(
const_cast<sem::Type*>(return_type), intrinsic_type, const_cast<sem::Type*>(return_type),
std::move(parameters)); std::move(parameters), overload.supported_stages);
} }
MatchState Impl::Match(ClosedState& closed, MatchState Impl::Match(ClosedState& closed,

File diff suppressed because it is too large Load Diff

View File

@ -11,64 +11,6 @@ See:
// clang-format off // clang-format off
/// ParameterInfo describes a parameter
struct ParameterInfo {
/// The parameter usage (parameter name in definition file)
ParameterUsage const usage;
/// Pointer to a list of indices that are used to match the parameter type.
/// The matcher indices index on Matchers::type and / or Matchers::number.
/// These indices are consumed by the matchers themselves.
/// The first index is always a TypeMatcher.
MatcherIndex const* const matcher_indices;
};
/// OpenTypeInfo describes an open type
struct OpenTypeInfo {
/// Name of the open type (e.g. 'T')
const char* name;
/// Optional type matcher constraint.
/// Either an index in Matchers::type, or kNoMatcher
MatcherIndex const matcher_index;
};
/// OpenNumberInfo describes an open number
struct OpenNumberInfo {
/// Name of the open number (e.g. 'N')
const char* name;
/// Optional number matcher constraint.
/// Either an index in Matchers::number, or kNoMatcher
MatcherIndex const matcher_index;
};
/// OverloadInfo describes a single function overload
struct OverloadInfo {
/// Total number of parameters for the overload
uint8_t const num_parameters;
/// Total number of open types for the overload
uint8_t const num_open_types;
/// Total number of open numbers for the overload
uint8_t const num_open_numbers;
/// Pointer to the first open type
OpenTypeInfo const* const open_types;
/// Pointer to the first open number
OpenNumberInfo const* const open_numbers;
/// Pointer to the first parameter
ParameterInfo const* const parameters;
/// Pointer to a list of matcher indices that index on Matchers::type and
/// Matchers::number, used to build the return type. If the function has no
/// return type then this is null.
MatcherIndex const* const return_matcher_indices;
};
/// IntrinsicInfo describes an intrinsic function
struct IntrinsicInfo {
/// Number of overloads of the intrinsic function
uint8_t const num_overloads;
/// Pointer to the start of the overloads for the function
OverloadInfo const* const overloads;
};
{{ with .Sem -}} {{ with .Sem -}}
{{ range .Types -}} {{ range .Types -}}
{{ template "Type" . }} {{ template "Type" . }}
@ -155,6 +97,10 @@ constexpr OverloadInfo kOverloads[] = {
{{- if $o.ReturnMatcherIndicesOffset }} &kMatcherIndices[{{$o.ReturnMatcherIndicesOffset}}] {{- if $o.ReturnMatcherIndicesOffset }} &kMatcherIndices[{{$o.ReturnMatcherIndicesOffset}}]
{{- else }} nullptr {{- else }} nullptr
{{- end }}, {{- end }},
/* supported_stages */ PipelineStageSet(
{{- range $i, $u := $o.CanBeUsedInStage.List -}}
{{- if $i -}}, {{end}}PipelineStage::k{{Title $u}}
{{- end }}),
}, },
{{- end }} {{- end }}
}; };

View File

@ -12,6 +12,7 @@
// See the License for the specific language governing permissions and // See the License for the specific language governing permissions and
// limitations under the License. // limitations under the License.
#include "src/ast/call_statement.h"
#include "src/resolver/resolver_test_helper.h" #include "src/resolver/resolver_test_helper.h"
namespace tint { namespace tint {
@ -283,7 +284,8 @@ TEST_P(FloatAllMatching, Scalar) {
params.push_back(Expr(1.0f)); params.push_back(Expr(1.0f));
} }
auto* builtin = Call(name, params); auto* builtin = Call(name, params);
WrapInFunction(builtin); Func("func", {}, ty.void_(), {create<ast::CallStatement>(builtin)},
{create<ast::StageDecoration>(ast::PipelineStage::kFragment)});
EXPECT_TRUE(r()->Resolve()) << r()->error(); EXPECT_TRUE(r()->Resolve()) << r()->error();
EXPECT_TRUE(TypeOf(builtin)->Is<sem::F32>()); EXPECT_TRUE(TypeOf(builtin)->Is<sem::F32>());
@ -298,7 +300,8 @@ TEST_P(FloatAllMatching, Vec2) {
params.push_back(vec2<f32>(1.0f, 1.0f)); params.push_back(vec2<f32>(1.0f, 1.0f));
} }
auto* builtin = Call(name, params); auto* builtin = Call(name, params);
WrapInFunction(builtin); Func("func", {}, ty.void_(), {create<ast::CallStatement>(builtin)},
{create<ast::StageDecoration>(ast::PipelineStage::kFragment)});
EXPECT_TRUE(r()->Resolve()) << r()->error(); EXPECT_TRUE(r()->Resolve()) << r()->error();
EXPECT_TRUE(TypeOf(builtin)->is_float_vector()); EXPECT_TRUE(TypeOf(builtin)->is_float_vector());
@ -313,7 +316,8 @@ TEST_P(FloatAllMatching, Vec3) {
params.push_back(vec3<f32>(1.0f, 1.0f, 1.0f)); params.push_back(vec3<f32>(1.0f, 1.0f, 1.0f));
} }
auto* builtin = Call(name, params); auto* builtin = Call(name, params);
WrapInFunction(builtin); Func("func", {}, ty.void_(), {create<ast::CallStatement>(builtin)},
{create<ast::StageDecoration>(ast::PipelineStage::kFragment)});
EXPECT_TRUE(r()->Resolve()) << r()->error(); EXPECT_TRUE(r()->Resolve()) << r()->error();
EXPECT_TRUE(TypeOf(builtin)->is_float_vector()); EXPECT_TRUE(TypeOf(builtin)->is_float_vector());
@ -328,7 +332,8 @@ TEST_P(FloatAllMatching, Vec4) {
params.push_back(vec4<f32>(1.0f, 1.0f, 1.0f, 1.0f)); params.push_back(vec4<f32>(1.0f, 1.0f, 1.0f, 1.0f));
} }
auto* builtin = Call(name, params); auto* builtin = Call(name, params);
WrapInFunction(builtin); Func("func", {}, ty.void_(), {create<ast::CallStatement>(builtin)},
{create<ast::StageDecoration>(ast::PipelineStage::kFragment)});
EXPECT_TRUE(r()->Resolve()) << r()->error(); EXPECT_TRUE(r()->Resolve()) << r()->error();
EXPECT_TRUE(TypeOf(builtin)->is_float_vector()); EXPECT_TRUE(TypeOf(builtin)->is_float_vector());

View File

@ -55,7 +55,8 @@ TEST_P(ResolverIntrinsicDerivativeTest, Scalar) {
Global("ident", ty.f32(), ast::StorageClass::kInput); Global("ident", ty.f32(), ast::StorageClass::kInput);
auto* expr = Call(name, "ident"); auto* expr = Call(name, "ident");
WrapInFunction(expr); Func("func", {}, ty.void_(), {create<ast::CallStatement>(expr)},
{create<ast::StageDecoration>(ast::PipelineStage::kFragment)});
EXPECT_TRUE(r()->Resolve()) << r()->error(); EXPECT_TRUE(r()->Resolve()) << r()->error();
@ -68,7 +69,8 @@ TEST_P(ResolverIntrinsicDerivativeTest, Vector) {
Global("ident", ty.vec4<f32>(), ast::StorageClass::kInput); Global("ident", ty.vec4<f32>(), ast::StorageClass::kInput);
auto* expr = Call(name, "ident"); auto* expr = Call(name, "ident");
WrapInFunction(expr); Func("func", {}, ty.void_(), {create<ast::CallStatement>(expr)},
{create<ast::StageDecoration>(ast::PipelineStage::kFragment)});
EXPECT_TRUE(r()->Resolve()) << r()->error(); EXPECT_TRUE(r()->Resolve()) << r()->error();
@ -1927,7 +1929,8 @@ TEST_P(ResolverIntrinsicTest_Texture, Call) {
param.buildSamplerVariable(this); param.buildSamplerVariable(this);
auto* call = Call(param.function, param.args(this)); auto* call = Call(param.function, param.args(this));
WrapInFunction(call); Func("func", {}, ty.void_(), {create<ast::CallStatement>(call)},
{create<ast::StageDecoration>(ast::PipelineStage::kFragment)});
ASSERT_TRUE(r()->Resolve()) << r()->error(); ASSERT_TRUE(r()->Resolve()) << r()->error();

View File

@ -0,0 +1,97 @@
// Copyright 2021 The Tint Authors.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include "src/resolver/resolver.h"
#include "gmock/gmock.h"
#include "src/ast/assignment_statement.h"
#include "src/ast/bitcast_expression.h"
#include "src/ast/break_statement.h"
#include "src/ast/call_statement.h"
#include "src/ast/continue_statement.h"
#include "src/ast/if_statement.h"
#include "src/ast/intrinsic_texture_helper_test.h"
#include "src/ast/loop_statement.h"
#include "src/ast/return_statement.h"
#include "src/ast/stage_decoration.h"
#include "src/ast/struct_block_decoration.h"
#include "src/ast/switch_statement.h"
#include "src/ast/unary_op_expression.h"
#include "src/ast/variable_decl_statement.h"
#include "src/resolver/resolver_test_helper.h"
#include "src/sem/call.h"
#include "src/sem/function.h"
#include "src/sem/member_accessor_expression.h"
#include "src/sem/sampled_texture_type.h"
#include "src/sem/statement.h"
#include "src/sem/variable.h"
using ::testing::ElementsAre;
using ::testing::HasSubstr;
namespace tint {
namespace resolver {
namespace {
using IntrinsicType = sem::IntrinsicType;
using ResolverIntrinsicValidationTest = ResolverTest;
TEST_F(ResolverIntrinsicValidationTest, InvalidPipelineStageDirect) {
// [[stage(compute)]] fn func { return dpdx(1.0); }
auto* dpdx = create<ast::CallExpression>(Source{{3, 4}}, Expr("dpdx"),
ast::ExpressionList{Expr(1.0f)});
Func(Source{{1, 2}}, "func", ast::VariableList{}, ty.void_(),
{create<ast::CallStatement>(dpdx)},
{Stage(ast::PipelineStage::kCompute)});
EXPECT_FALSE(r()->Resolve());
EXPECT_EQ(r()->error(),
"3:4 error: built-in cannot be used by compute pipeline stage");
}
TEST_F(ResolverIntrinsicValidationTest, InvalidPipelineStageIndirect) {
// fn f0 { return dpdx(1.0); }
// fn f1 { f0(); }
// fn f2 { f1(); }
// [[stage(compute)]] fn main { return f2(); }
auto* dpdx = create<ast::CallExpression>(Source{{3, 4}}, Expr("dpdx"),
ast::ExpressionList{Expr(1.0f)});
Func(Source{{1, 2}}, "f0", ast::VariableList{}, ty.void_(),
{create<ast::CallStatement>(dpdx)});
Func(Source{{3, 4}}, "f1", ast::VariableList{}, ty.void_(),
{create<ast::CallStatement>(Call("f0"))});
Func(Source{{5, 6}}, "f2", ast::VariableList{}, ty.void_(),
{create<ast::CallStatement>(Call("f1"))});
Func(Source{{7, 8}}, "main", ast::VariableList{}, ty.void_(),
{create<ast::CallStatement>(Call("f2"))},
{Stage(ast::PipelineStage::kCompute)});
EXPECT_FALSE(r()->Resolve());
EXPECT_EQ(r()->error(),
R"(3:4 error: built-in cannot be used by compute pipeline stage
1:2 note: called by function 'f0'
3:4 note: called by function 'f1'
5:6 note: called by function 'f2'
7:8 note: called by entry point 'main')");
}
} // namespace
} // namespace resolver
} // namespace tint

View File

@ -237,6 +237,10 @@ bool Resolver::ResolveInternal() {
} }
} }
if (!ValidatePipelineStages()) {
return false;
}
bool result = true; bool result = true;
for (auto* node : builder_->ASTNodes().Objects()) { for (auto* node : builder_->ASTNodes().Objects()) {
@ -1129,6 +1133,10 @@ bool Resolver::ValidateEntryPoint(const ast::Function* func,
bool Resolver::Function(ast::Function* func) { bool Resolver::Function(ast::Function* func) {
auto* info = function_infos_.Create<FunctionInfo>(func); auto* info = function_infos_.Create<FunctionInfo>(func);
if (func->IsEntryPoint()) {
entry_points_.emplace_back(info);
}
TINT_SCOPED_ASSIGNMENT(current_function_, info); TINT_SCOPED_ASSIGNMENT(current_function_, info);
variable_stack_.push_scope(); variable_stack_.push_scope();
@ -1707,6 +1715,10 @@ bool Resolver::IntrinsicCall(ast::CallExpression* call,
builder_->Sem().Add( builder_->Sem().Add(
call, builder_->create<sem::Call>(call, result, current_statement_)); call, builder_->create<sem::Call>(call, result, current_statement_));
SetType(call, result->ReturnType()); SetType(call, result->ReturnType());
current_function_->intrinsic_calls.emplace_back(
IntrinsicCallInfo{call, result});
return true; return true;
} }
@ -2460,25 +2472,76 @@ void Resolver::SetType(ast::Expression* expr,
expr_info_.emplace(expr, ExpressionInfo{type, type_name, current_statement_}); expr_info_.emplace(expr, ExpressionInfo{type, type_name, current_statement_});
} }
bool Resolver::ValidatePipelineStages() {
auto check_intrinsic_calls = [&](FunctionInfo* func,
FunctionInfo* entry_point) {
auto stage = entry_point->declaration->pipeline_stage();
for (auto& call : func->intrinsic_calls) {
if (!call.intrinsic->SupportedStages().Contains(stage)) {
std::stringstream err;
err << "built-in cannot be used by " << stage << " pipeline stage";
diagnostics_.add_error(err.str(), call.call->source());
if (func != entry_point) {
TraverseCallChain(entry_point, func, [&](FunctionInfo* f) {
diagnostics_.add_note(
"called by function '" +
builder_->Symbols().NameFor(f->declaration->symbol()) + "'",
f->declaration->source());
});
diagnostics_.add_note("called by entry point '" +
builder_->Symbols().NameFor(
entry_point->declaration->symbol()) +
"'",
entry_point->declaration->source());
}
return false;
}
}
return true;
};
for (auto* entry_point : entry_points_) {
if (!check_intrinsic_calls(entry_point, entry_point)) {
return false;
}
for (auto* func : entry_point->transitive_calls) {
if (!check_intrinsic_calls(func, entry_point)) {
return false;
}
}
}
return true;
}
template <typename CALLBACK>
void Resolver::TraverseCallChain(FunctionInfo* from,
FunctionInfo* to,
CALLBACK&& callback) const {
for (auto* f : from->transitive_calls) {
if (f == to) {
callback(f);
return;
}
if (f->transitive_calls.contains(to)) {
TraverseCallChain(f, to, callback);
callback(f);
return;
}
}
TINT_ICE(diagnostics_)
<< "TraverseCallChain() 'from' does not transitively call 'to'";
}
void Resolver::CreateSemanticNodes() const { void Resolver::CreateSemanticNodes() const {
auto& sem = builder_->Sem(); auto& sem = builder_->Sem();
// Collate all the 'ancestor_entry_points' - this is a map of function // Collate all the 'ancestor_entry_points' - this is a map of function
// symbol to all the entry points that transitively call the function. // symbol to all the entry points that transitively call the function.
std::unordered_map<Symbol, std::vector<Symbol>> ancestor_entry_points; std::unordered_map<Symbol, std::vector<Symbol>> ancestor_entry_points;
for (auto* func : builder_->AST().Functions()) { for (auto* entry_point : entry_points_) {
auto it = function_to_info_.find(func); for (auto* call : entry_point->transitive_calls) {
if (it == function_to_info_.end()) {
continue; // Resolver has likely errored. Process what we can.
}
auto* info = it->second;
if (!func->IsEntryPoint()) {
continue;
}
for (auto* call : info->transitive_calls) {
auto& vec = ancestor_entry_points[call->declaration->symbol()]; auto& vec = ancestor_entry_points[call->declaration->symbol()];
vec.emplace_back(func->symbol()); vec.emplace_back(entry_point->declaration->symbol());
} }
} }

View File

@ -51,6 +51,7 @@ class Variable;
} // namespace ast } // namespace ast
namespace sem { namespace sem {
class Array; class Array;
class Intrinsic;
class Statement; class Statement;
} // namespace sem } // namespace sem
@ -100,6 +101,11 @@ class Resolver {
sem::BindingPoint binding_point; sem::BindingPoint binding_point;
}; };
struct IntrinsicCallInfo {
const ast::CallExpression* call;
const sem::Intrinsic* intrinsic;
};
/// Structure holding semantic information about a function. /// Structure holding semantic information about a function.
/// Used to build the sem::Function nodes at the end of resolving. /// Used to build the sem::Function nodes at the end of resolving.
struct FunctionInfo { struct FunctionInfo {
@ -115,9 +121,13 @@ class Resolver {
sem::Type* return_type = nullptr; sem::Type* return_type = nullptr;
std::string return_type_name; std::string return_type_name;
std::array<sem::WorkgroupDimension, 3> workgroup_size; std::array<sem::WorkgroupDimension, 3> workgroup_size;
std::vector<IntrinsicCallInfo> intrinsic_calls;
// List of transitive calls this function makes // List of transitive calls this function makes
UniqueVector<FunctionInfo*> transitive_calls; UniqueVector<FunctionInfo*> transitive_calls;
// List of entry point functions that transitively call this function
UniqueVector<FunctionInfo*> ancestor_entry_points;
}; };
/// Structure holding semantic information about an expression. /// Structure holding semantic information about an expression.
@ -183,6 +193,8 @@ class Resolver {
/// @returns true on success, false on error /// @returns true on success, false on error
bool ResolveInternal(); bool ResolveInternal();
bool ValidatePipelineStages();
/// Creates the nodes and adds them to the sem::Info mappings of the /// Creates the nodes and adds them to the sem::Info mappings of the
/// ProgramBuilder. /// ProgramBuilder.
void CreateSemanticNodes() const; void CreateSemanticNodes() const;
@ -359,12 +371,18 @@ class Resolver {
/// @param node the AST node. /// @param node the AST node.
void Mark(const ast::Node* node); void Mark(const ast::Node* node);
template <typename CALLBACK>
void TraverseCallChain(FunctionInfo* from,
FunctionInfo* to,
CALLBACK&& callback) const;
ProgramBuilder* const builder_; ProgramBuilder* const builder_;
diag::List& diagnostics_; diag::List& diagnostics_;
std::unique_ptr<IntrinsicTable> const intrinsic_table_; std::unique_ptr<IntrinsicTable> const intrinsic_table_;
sem::BlockStatement* current_block_ = nullptr; sem::BlockStatement* current_block_ = nullptr;
ScopeStack<VariableInfo*> variable_stack_; ScopeStack<VariableInfo*> variable_stack_;
std::unordered_map<Symbol, FunctionInfo*> symbol_to_function_; std::unordered_map<Symbol, FunctionInfo*> symbol_to_function_;
std::vector<FunctionInfo*> entry_points_;
std::unordered_map<const ast::Function*, FunctionInfo*> function_to_info_; std::unordered_map<const ast::Function*, FunctionInfo*> function_to_info_;
std::unordered_map<const ast::Variable*, VariableInfo*> variable_to_info_; std::unordered_map<const ast::Variable*, VariableInfo*> variable_to_info_;
std::unordered_map<ast::CallExpression*, FunctionCallInfo> function_calls_; std::unordered_map<ast::CallExpression*, FunctionCallInfo> function_calls_;

View File

@ -89,8 +89,11 @@ bool IsBarrierIntrinsic(IntrinsicType i) {
Intrinsic::Intrinsic(IntrinsicType type, Intrinsic::Intrinsic(IntrinsicType type,
sem::Type* return_type, sem::Type* return_type,
const ParameterList& parameters) const ParameterList& parameters,
: Base(return_type, parameters), type_(type) {} PipelineStageSet supported_stages)
: Base(return_type, parameters),
type_(type),
supported_stages_(supported_stages) {}
Intrinsic::~Intrinsic() = default; Intrinsic::~Intrinsic() = default;

View File

@ -19,6 +19,7 @@
#include "src/sem/call_target.h" #include "src/sem/call_target.h"
#include "src/sem/intrinsic_type.h" #include "src/sem/intrinsic_type.h"
#include "src/sem/pipeline_stage_set.h"
namespace tint { namespace tint {
namespace sem { namespace sem {
@ -75,9 +76,12 @@ class Intrinsic : public Castable<Intrinsic, CallTarget> {
/// @param type the intrinsic type /// @param type the intrinsic type
/// @param return_type the return type for the intrinsic call /// @param return_type the return type for the intrinsic call
/// @param parameters the parameters for the intrinsic overload /// @param parameters the parameters for the intrinsic overload
/// @param supported_stages the pipeline stages that this intrinsic can be
/// used in
Intrinsic(IntrinsicType type, Intrinsic(IntrinsicType type,
sem::Type* return_type, sem::Type* return_type,
const ParameterList& parameters); const ParameterList& parameters,
PipelineStageSet supported_stages);
/// Destructor /// Destructor
~Intrinsic() override; ~Intrinsic() override;
@ -85,6 +89,9 @@ class Intrinsic : public Castable<Intrinsic, CallTarget> {
/// @return the type of the intrinsic /// @return the type of the intrinsic
IntrinsicType Type() const { return type_; } IntrinsicType Type() const { return type_; }
/// @return the pipeline stages that this intrinsic can be used in
PipelineStageSet SupportedStages() const { return supported_stages_; }
/// @returns the name of the intrinsic function type. The spelling, including /// @returns the name of the intrinsic function type. The spelling, including
/// case, matches the name in the WGSL spec. /// case, matches the name in the WGSL spec.
const char* str() const; const char* str() const;
@ -118,6 +125,7 @@ class Intrinsic : public Castable<Intrinsic, CallTarget> {
private: private:
IntrinsicType const type_; IntrinsicType const type_;
PipelineStageSet const supported_stages_;
}; };
/// Emits the name of the intrinsic function type. The spelling, including case, /// Emits the name of the intrinsic function type. The spelling, including case,

View File

@ -0,0 +1,29 @@
// Copyright 2021 The Tint Authors.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#ifndef SRC_SEM_PIPELINE_STAGE_SET_H_
#define SRC_SEM_PIPELINE_STAGE_SET_H_
#include "src/ast/pipeline_stage.h"
#include "src/utils/enum_set.h"
namespace tint {
namespace sem {
using PipelineStageSet = utils::EnumSet<ast::PipelineStage>;
} // namespace sem
} // namespace tint
#endif // SRC_SEM_PIPELINE_STAGE_SET_H_

View File

@ -37,6 +37,10 @@ struct UniqueVector {
} }
} }
/// @returns true if the vector contains `item`
/// @param item the item
bool contains(const T& item) const { return set.count(item); }
/// @returns the number of items in the vector /// @returns the number of items in the vector
size_t size() const { return vector.size(); } size_t size() const { return vector.size(); }
@ -47,7 +51,7 @@ struct UniqueVector {
ConstIterator end() const { return vector.end(); } ConstIterator end() const { return vector.end(); }
/// @returns a const reference to the internal vector /// @returns a const reference to the internal vector
operator const std::vector<T>&() const { return vector; } operator const std::vector<T> &() const { return vector; }
private: private:
std::vector<T> vector; std::vector<T> vector;

View File

@ -163,7 +163,8 @@ TEST_P(HlslIntrinsicTest, Emit) {
auto* call = GenerateCall(param.intrinsic, param.type, this); auto* call = GenerateCall(param.intrinsic, param.type, this);
ASSERT_NE(nullptr, call) << "Unhandled intrinsic"; ASSERT_NE(nullptr, call) << "Unhandled intrinsic";
WrapInFunction(call); Func("func", {}, ty.void_(), {create<ast::CallStatement>(call)},
{create<ast::StageDecoration>(ast::PipelineStage::kFragment)});
GeneratorImpl& gen = Build(); GeneratorImpl& gen = Build();

View File

@ -12,6 +12,7 @@
// See the License for the specific language governing permissions and // See the License for the specific language governing permissions and
// limitations under the License. // limitations under the License.
#include "src/ast/call_statement.h"
#include "src/sem/call.h" #include "src/sem/call.h"
#include "src/writer/msl/test_helper.h" #include "src/writer/msl/test_helper.h"
@ -176,7 +177,8 @@ TEST_P(MslIntrinsicTest, Emit) {
auto* call = GenerateCall(param.intrinsic, param.type, this); auto* call = GenerateCall(param.intrinsic, param.type, this);
ASSERT_NE(nullptr, call) << "Unhandled intrinsic"; ASSERT_NE(nullptr, call) << "Unhandled intrinsic";
WrapInFunction(call); Func("func", {}, ty.void_(), {create<ast::CallStatement>(call)},
{create<ast::StageDecoration>(ast::PipelineStage::kFragment)});
GeneratorImpl& gen = Build(); GeneratorImpl& gen = Build();

View File

@ -414,7 +414,8 @@ TEST_P(IntrinsicDeriveTest, Call_Derivative_Scalar) {
auto* var = Global("v", ty.f32(), ast::StorageClass::kPrivate); auto* var = Global("v", ty.f32(), ast::StorageClass::kPrivate);
auto* expr = Call(param.name, "v"); auto* expr = Call(param.name, "v");
WrapInFunction(expr); Func("func", {}, ty.void_(), {create<ast::CallStatement>(expr)},
{create<ast::StageDecoration>(ast::PipelineStage::kFragment)});
spirv::Builder& b = Build(); spirv::Builder& b = Build();
@ -439,7 +440,8 @@ TEST_P(IntrinsicDeriveTest, Call_Derivative_Vector) {
auto* var = Global("v", ty.vec3<f32>(), ast::StorageClass::kPrivate); auto* var = Global("v", ty.vec3<f32>(), ast::StorageClass::kPrivate);
auto* expr = Call(param.name, "v"); auto* expr = Call(param.name, "v");
WrapInFunction(expr); Func("func", {}, ty.void_(), {create<ast::CallStatement>(expr)},
{create<ast::StageDecoration>(ast::PipelineStage::kFragment)});
spirv::Builder& b = Build(); spirv::Builder& b = Build();

View File

@ -3462,7 +3462,9 @@ TEST_P(IntrinsicTextureTest, Call) {
auto* call = auto* call =
create<ast::CallExpression>(Expr(param.function), param.args(this)); create<ast::CallExpression>(Expr(param.function), param.args(this));
WrapInFunction(call);
Func("func", {}, ty.void_(), {create<ast::CallStatement>(call)},
{create<ast::StageDecoration>(ast::PipelineStage::kFragment)});
spirv::Builder& b = Build(); spirv::Builder& b = Build();
@ -3515,7 +3517,8 @@ TEST_P(IntrinsicTextureTest, OutsideFunction_IsError) {
auto* call = auto* call =
create<ast::CallExpression>(Expr(param.function), param.args(this)); create<ast::CallExpression>(Expr(param.function), param.args(this));
WrapInFunction(call); Func("func", {}, ty.void_(), {create<ast::CallStatement>(call)},
{create<ast::StageDecoration>(ast::PipelineStage::kFragment)});
spirv::Builder& b = Build(); spirv::Builder& b = Build();

View File

@ -232,6 +232,7 @@ tint_unittests_source_set("tint_unittests_core_src") {
"../src/resolver/function_validation_test.cc", "../src/resolver/function_validation_test.cc",
"../src/resolver/host_shareable_validation_test.cc", "../src/resolver/host_shareable_validation_test.cc",
"../src/resolver/intrinsic_test.cc", "../src/resolver/intrinsic_test.cc",
"../src/resolver/intrinsic_validation_test.cc",
"../src/resolver/is_host_shareable_test.cc", "../src/resolver/is_host_shareable_test.cc",
"../src/resolver/is_storeable_test.cc", "../src/resolver/is_storeable_test.cc",
"../src/resolver/pipeline_overridable_constant_test.cc", "../src/resolver/pipeline_overridable_constant_test.cc",

View File

@ -96,6 +96,8 @@ type Overload struct {
// These indices are consumed by the matchers themselves. // These indices are consumed by the matchers themselves.
// The first index is always a TypeMatcher. // The first index is always a TypeMatcher.
ReturnMatcherIndicesOffset *int ReturnMatcherIndicesOffset *int
// StageUses describes the stages an overload can be used in
CanBeUsedInStage sem.StageUses
} }
// Function is used to create the C++ IntrinsicInfo structure // Function is used to create the C++ IntrinsicInfo structure
@ -193,6 +195,7 @@ func (b *intrinsicTableBuilder) buildOverload(o *sem.Overload) (Overload, error)
OpenNumbersOffset: b.lut.openNumbers.Add(ob.openNumbers), OpenNumbersOffset: b.lut.openNumbers.Add(ob.openNumbers),
ParametersOffset: b.lut.parameters.Add(ob.parameters), ParametersOffset: b.lut.parameters.Add(ob.parameters),
ReturnMatcherIndicesOffset: ob.returnTypeMatcherIndicesOffset, ReturnMatcherIndicesOffset: ob.returnTypeMatcherIndicesOffset,
CanBeUsedInStage: o.CanBeUsedInStage,
}, nil }, nil
} }

View File

@ -146,6 +146,21 @@ type StageUses struct {
Compute bool Compute bool
} }
// List returns the stage uses as a string list
func (u StageUses) List() []string {
out := []string{}
if u.Vertex {
out = append(out, "vertex")
}
if u.Fragment {
out = append(out, "fragment")
}
if u.Compute {
out = append(out, "compute")
}
return out
}
// Format implements the fmt.Formatter interface // Format implements the fmt.Formatter interface
func (o Overload) Format(w fmt.State, verb rune) { func (o Overload) Format(w fmt.State, verb rune) {
fmt.Fprintf(w, "fn %v", o.Function.Name) fmt.Fprintf(w, "fn %v", o.Function.Name)