sem::Function: Add ReturnType()
This is the resolved, semantic, return type of the function. Bug: tint:724 Change-Id: I4ef9f7874414a3ea48131d0102da776f6d82a729 Reviewed-on: https://dawn-review.googlesource.com/c/tint/+/49526 Commit-Queue: Ben Clayton <bclayton@google.com> Reviewed-by: Antonio Maiorano <amaiorano@google.com>
This commit is contained in:
parent
fa9af4b0ef
commit
3068dcb3d7
|
@ -588,7 +588,7 @@ bool Resolver::ValidateFunction(const ast::Function* func,
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
if (!func->return_type()->Is<sem::Void>()) {
|
if (!info->return_type->Is<sem::Void>()) {
|
||||||
if (func->body()) {
|
if (func->body()) {
|
||||||
if (!func->get_last_statement() ||
|
if (!func->get_last_statement() ||
|
||||||
!func->get_last_statement()->Is<ast::ReturnStatement>()) {
|
!func->get_last_statement()->Is<ast::ReturnStatement>()) {
|
||||||
|
@ -809,9 +809,9 @@ bool Resolver::ValidateEntryPoint(const ast::Function* func,
|
||||||
builtins.clear();
|
builtins.clear();
|
||||||
locations.clear();
|
locations.clear();
|
||||||
|
|
||||||
if (!func->return_type()->Is<sem::Void>()) {
|
if (!info->return_type->Is<sem::Void>()) {
|
||||||
if (!validate_entry_point_decorations(func->return_type_decorations(),
|
if (!validate_entry_point_decorations(func->return_type_decorations(),
|
||||||
func->return_type(), func->source(),
|
info->return_type, func->source(),
|
||||||
ParamOrRetType::kReturnType)) {
|
ParamOrRetType::kReturnType)) {
|
||||||
return false;
|
return false;
|
||||||
}
|
}
|
||||||
|
@ -844,9 +844,9 @@ bool Resolver::ValidateEntryPoint(const ast::Function* func,
|
||||||
}
|
}
|
||||||
|
|
||||||
bool Resolver::Function(ast::Function* func) {
|
bool Resolver::Function(ast::Function* func) {
|
||||||
auto* func_info = function_infos_.Create<FunctionInfo>(func);
|
auto* info = function_infos_.Create<FunctionInfo>(func);
|
||||||
|
|
||||||
ScopedAssignment<FunctionInfo*> sa(current_function_, func_info);
|
ScopedAssignment<FunctionInfo*> sa(current_function_, info);
|
||||||
|
|
||||||
variable_stack_.push_scope();
|
variable_stack_.push_scope();
|
||||||
for (auto* param : func->params()) {
|
for (auto* param : func->params()) {
|
||||||
|
@ -862,11 +862,10 @@ bool Resolver::Function(ast::Function* func) {
|
||||||
}
|
}
|
||||||
|
|
||||||
variable_stack_.set(param->symbol(), param_info);
|
variable_stack_.set(param->symbol(), param_info);
|
||||||
func_info->parameters.emplace_back(param_info);
|
info->parameters.emplace_back(param_info);
|
||||||
|
|
||||||
if (!ApplyStorageClassUsageToType(param->declared_storage_class(),
|
if (!ApplyStorageClassUsageToType(param->declared_storage_class(),
|
||||||
param->declared_type(),
|
param_info->type, param->source())) {
|
||||||
param->source())) {
|
|
||||||
diagnostics_.add_note("while instantiating parameter " +
|
diagnostics_.add_note("while instantiating parameter " +
|
||||||
builder_->Symbols().NameFor(param->symbol()),
|
builder_->Symbols().NameFor(param->symbol()),
|
||||||
param->source());
|
param->source());
|
||||||
|
@ -874,21 +873,21 @@ bool Resolver::Function(ast::Function* func) {
|
||||||
}
|
}
|
||||||
|
|
||||||
if (auto* str = param_info->type->As<sem::StructType>()) {
|
if (auto* str = param_info->type->As<sem::StructType>()) {
|
||||||
auto* info = Structure(str);
|
auto* str_info = Structure(str);
|
||||||
if (!info) {
|
if (!str_info) {
|
||||||
return false;
|
return false;
|
||||||
}
|
}
|
||||||
switch (func->pipeline_stage()) {
|
switch (func->pipeline_stage()) {
|
||||||
case ast::PipelineStage::kVertex:
|
case ast::PipelineStage::kVertex:
|
||||||
info->pipeline_stage_uses.emplace(
|
str_info->pipeline_stage_uses.emplace(
|
||||||
sem::PipelineStageUsage::kVertexInput);
|
sem::PipelineStageUsage::kVertexInput);
|
||||||
break;
|
break;
|
||||||
case ast::PipelineStage::kFragment:
|
case ast::PipelineStage::kFragment:
|
||||||
info->pipeline_stage_uses.emplace(
|
str_info->pipeline_stage_uses.emplace(
|
||||||
sem::PipelineStageUsage::kFragmentInput);
|
sem::PipelineStageUsage::kFragmentInput);
|
||||||
break;
|
break;
|
||||||
case ast::PipelineStage::kCompute:
|
case ast::PipelineStage::kCompute:
|
||||||
info->pipeline_stage_uses.emplace(
|
str_info->pipeline_stage_uses.emplace(
|
||||||
sem::PipelineStageUsage::kComputeInput);
|
sem::PipelineStageUsage::kComputeInput);
|
||||||
break;
|
break;
|
||||||
case ast::PipelineStage::kNone:
|
case ast::PipelineStage::kNone:
|
||||||
|
@ -897,7 +896,22 @@ bool Resolver::Function(ast::Function* func) {
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
if (auto* str = Canonical(func->return_type())->As<sem::StructType>()) {
|
if (func->return_type().ast || func->return_type().sem) {
|
||||||
|
info->return_type = func->return_type();
|
||||||
|
if (!info->return_type) {
|
||||||
|
info->return_type = Type(func->return_type().ast);
|
||||||
|
}
|
||||||
|
if (!info->return_type) {
|
||||||
|
return false;
|
||||||
|
}
|
||||||
|
} else {
|
||||||
|
info->return_type = builder_->create<sem::Void>();
|
||||||
|
}
|
||||||
|
|
||||||
|
info->return_type_name = info->return_type->FriendlyName(builder_->Symbols());
|
||||||
|
info->return_type = Canonical(info->return_type);
|
||||||
|
|
||||||
|
if (auto* str = info->return_type->As<sem::StructType>()) {
|
||||||
if (!ApplyStorageClassUsageToType(ast::StorageClass::kNone, str,
|
if (!ApplyStorageClassUsageToType(ast::StorageClass::kNone, str,
|
||||||
func->source())) {
|
func->source())) {
|
||||||
diagnostics_.add_note("while instantiating return type for " +
|
diagnostics_.add_note("while instantiating return type for " +
|
||||||
|
@ -906,21 +920,21 @@ bool Resolver::Function(ast::Function* func) {
|
||||||
return false;
|
return false;
|
||||||
}
|
}
|
||||||
|
|
||||||
auto* info = Structure(str);
|
auto* str_info = Structure(str);
|
||||||
if (!info) {
|
if (!str_info) {
|
||||||
return false;
|
return false;
|
||||||
}
|
}
|
||||||
switch (func->pipeline_stage()) {
|
switch (func->pipeline_stage()) {
|
||||||
case ast::PipelineStage::kVertex:
|
case ast::PipelineStage::kVertex:
|
||||||
info->pipeline_stage_uses.emplace(
|
str_info->pipeline_stage_uses.emplace(
|
||||||
sem::PipelineStageUsage::kVertexOutput);
|
sem::PipelineStageUsage::kVertexOutput);
|
||||||
break;
|
break;
|
||||||
case ast::PipelineStage::kFragment:
|
case ast::PipelineStage::kFragment:
|
||||||
info->pipeline_stage_uses.emplace(
|
str_info->pipeline_stage_uses.emplace(
|
||||||
sem::PipelineStageUsage::kFragmentOutput);
|
sem::PipelineStageUsage::kFragmentOutput);
|
||||||
break;
|
break;
|
||||||
case ast::PipelineStage::kCompute:
|
case ast::PipelineStage::kCompute:
|
||||||
info->pipeline_stage_uses.emplace(
|
str_info->pipeline_stage_uses.emplace(
|
||||||
sem::PipelineStageUsage::kComputeOutput);
|
sem::PipelineStageUsage::kComputeOutput);
|
||||||
break;
|
break;
|
||||||
case ast::PipelineStage::kNone:
|
case ast::PipelineStage::kNone:
|
||||||
|
@ -943,15 +957,15 @@ bool Resolver::Function(ast::Function* func) {
|
||||||
Mark(deco);
|
Mark(deco);
|
||||||
}
|
}
|
||||||
|
|
||||||
if (!ValidateFunction(func, func_info)) {
|
if (!ValidateFunction(func, info)) {
|
||||||
return false;
|
return false;
|
||||||
}
|
}
|
||||||
|
|
||||||
// Register the function information _after_ processing the statements. This
|
// Register the function information _after_ processing the statements. This
|
||||||
// allows us to catch a function calling itself when determining the call
|
// allows us to catch a function calling itself when determining the call
|
||||||
// information as this function doesn't exist until it's finished.
|
// information as this function doesn't exist until it's finished.
|
||||||
symbol_to_function_[func->symbol()] = func_info;
|
symbol_to_function_[func->symbol()] = info;
|
||||||
function_to_info_.emplace(func, func_info);
|
function_to_info_.emplace(func, info);
|
||||||
|
|
||||||
return true;
|
return true;
|
||||||
}
|
}
|
||||||
|
@ -1274,7 +1288,7 @@ bool Resolver::Call(ast::CallExpression* call) {
|
||||||
auto* function = iter->second;
|
auto* function = iter->second;
|
||||||
function_calls_.emplace(call,
|
function_calls_.emplace(call,
|
||||||
FunctionCallInfo{function, current_statement_});
|
FunctionCallInfo{function, current_statement_});
|
||||||
SetType(call, function->declaration->return_type());
|
SetType(call, function->return_type, function->return_type_name);
|
||||||
}
|
}
|
||||||
|
|
||||||
return true;
|
return true;
|
||||||
|
@ -2093,8 +2107,8 @@ void Resolver::CreateSemanticNodes() const {
|
||||||
auto* info = it.second;
|
auto* info = it.second;
|
||||||
|
|
||||||
auto* sem_func = builder_->create<sem::Function>(
|
auto* sem_func = builder_->create<sem::Function>(
|
||||||
info->declaration, remap_vars(info->parameters),
|
info->declaration, const_cast<sem::Type*>(info->return_type),
|
||||||
remap_vars(info->referenced_module_vars),
|
remap_vars(info->parameters), remap_vars(info->referenced_module_vars),
|
||||||
remap_vars(info->local_referenced_module_vars), info->return_statements,
|
remap_vars(info->local_referenced_module_vars), info->return_statements,
|
||||||
ancestor_entry_points[func->symbol()]);
|
ancestor_entry_points[func->symbol()]);
|
||||||
func_info_to_sem_func.emplace(info, sem_func);
|
func_info_to_sem_func.emplace(info, sem_func);
|
||||||
|
@ -2479,19 +2493,19 @@ Resolver::StructInfo* Resolver::Structure(const sem::StructType* str) {
|
||||||
}
|
}
|
||||||
|
|
||||||
bool Resolver::ValidateReturn(const ast::ReturnStatement* ret) {
|
bool Resolver::ValidateReturn(const ast::ReturnStatement* ret) {
|
||||||
sem::Type* func_type = current_function_->declaration->return_type();
|
auto* func_type = current_function_->return_type;
|
||||||
|
|
||||||
auto* ret_type = ret->has_value() ? TypeOf(ret->value())->UnwrapAll()
|
auto* ret_type = ret->has_value() ? TypeOf(ret->value())->UnwrapAll()
|
||||||
: builder_->ty.void_();
|
: builder_->ty.void_();
|
||||||
|
|
||||||
if (func_type->UnwrapAll() != ret_type) {
|
if (func_type->UnwrapAll() != ret_type) {
|
||||||
diagnostics_.add_error(
|
diagnostics_.add_error("v-000y",
|
||||||
"v-000y",
|
"return statement type must match its function "
|
||||||
"return statement type must match its function "
|
"return type, returned '" +
|
||||||
"return type, returned '" +
|
ret_type->FriendlyName(builder_->Symbols()) +
|
||||||
ret_type->FriendlyName(builder_->Symbols()) + "', expected '" +
|
"', expected '" +
|
||||||
func_type->FriendlyName(builder_->Symbols()) + "'",
|
current_function_->return_type_name + "'",
|
||||||
ret->source());
|
ret->source());
|
||||||
return false;
|
return false;
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
|
@ -119,6 +119,8 @@ class Resolver {
|
||||||
UniqueVector<VariableInfo*> referenced_module_vars;
|
UniqueVector<VariableInfo*> referenced_module_vars;
|
||||||
UniqueVector<VariableInfo*> local_referenced_module_vars;
|
UniqueVector<VariableInfo*> local_referenced_module_vars;
|
||||||
std::vector<const ast::ReturnStatement*> return_statements;
|
std::vector<const ast::ReturnStatement*> return_statements;
|
||||||
|
sem::Type const* return_type = nullptr;
|
||||||
|
std::string return_type_name;
|
||||||
|
|
||||||
// List of transitive calls this function makes
|
// List of transitive calls this function makes
|
||||||
UniqueVector<FunctionInfo*> transitive_calls;
|
UniqueVector<FunctionInfo*> transitive_calls;
|
||||||
|
|
|
@ -758,6 +758,7 @@ TEST_F(ResolverTest, Function_Parameters) {
|
||||||
EXPECT_EQ(func_sem->Parameters()[0]->Declaration(), param_a);
|
EXPECT_EQ(func_sem->Parameters()[0]->Declaration(), param_a);
|
||||||
EXPECT_EQ(func_sem->Parameters()[1]->Declaration(), param_b);
|
EXPECT_EQ(func_sem->Parameters()[1]->Declaration(), param_b);
|
||||||
EXPECT_EQ(func_sem->Parameters()[2]->Declaration(), param_c);
|
EXPECT_EQ(func_sem->Parameters()[2]->Declaration(), param_c);
|
||||||
|
EXPECT_TRUE(func_sem->ReturnType()->Is<sem::Void>());
|
||||||
}
|
}
|
||||||
|
|
||||||
TEST_F(ResolverTest, Function_RegisterInputOutputVariables) {
|
TEST_F(ResolverTest, Function_RegisterInputOutputVariables) {
|
||||||
|
@ -785,6 +786,7 @@ TEST_F(ResolverTest, Function_RegisterInputOutputVariables) {
|
||||||
auto* func_sem = Sem().Get(func);
|
auto* func_sem = Sem().Get(func);
|
||||||
ASSERT_NE(func_sem, nullptr);
|
ASSERT_NE(func_sem, nullptr);
|
||||||
EXPECT_EQ(func_sem->Parameters().size(), 0u);
|
EXPECT_EQ(func_sem->Parameters().size(), 0u);
|
||||||
|
EXPECT_TRUE(func_sem->ReturnType()->Is<sem::Void>());
|
||||||
|
|
||||||
const auto& vars = func_sem->ReferencedModuleVariables();
|
const auto& vars = func_sem->ReferencedModuleVariables();
|
||||||
ASSERT_EQ(vars.size(), 5u);
|
ASSERT_EQ(vars.size(), 5u);
|
||||||
|
@ -851,6 +853,7 @@ TEST_F(ResolverTest, Function_NotRegisterFunctionVariable) {
|
||||||
ASSERT_NE(func_sem, nullptr);
|
ASSERT_NE(func_sem, nullptr);
|
||||||
|
|
||||||
EXPECT_EQ(func_sem->ReferencedModuleVariables().size(), 0u);
|
EXPECT_EQ(func_sem->ReferencedModuleVariables().size(), 0u);
|
||||||
|
EXPECT_TRUE(func_sem->ReturnType()->Is<sem::Void>());
|
||||||
}
|
}
|
||||||
|
|
||||||
TEST_F(ResolverTest, Function_ReturnStatements) {
|
TEST_F(ResolverTest, Function_ReturnStatements) {
|
||||||
|
@ -875,6 +878,7 @@ TEST_F(ResolverTest, Function_ReturnStatements) {
|
||||||
EXPECT_EQ(func_sem->ReturnStatements().size(), 2u);
|
EXPECT_EQ(func_sem->ReturnStatements().size(), 2u);
|
||||||
EXPECT_EQ(func_sem->ReturnStatements()[0], ret_1);
|
EXPECT_EQ(func_sem->ReturnStatements()[0], ret_1);
|
||||||
EXPECT_EQ(func_sem->ReturnStatements()[1], ret_foo);
|
EXPECT_EQ(func_sem->ReturnStatements()[1], ret_foo);
|
||||||
|
EXPECT_TRUE(func_sem->ReturnType()->Is<sem::F32>());
|
||||||
}
|
}
|
||||||
|
|
||||||
TEST_F(ResolverTest, Expr_MemberAccessor_Struct) {
|
TEST_F(ResolverTest, Expr_MemberAccessor_Struct) {
|
||||||
|
|
|
@ -41,12 +41,13 @@ ParameterList GetParameters(ast::Function* ast) {
|
||||||
} // namespace
|
} // namespace
|
||||||
|
|
||||||
Function::Function(ast::Function* declaration,
|
Function::Function(ast::Function* declaration,
|
||||||
|
Type* return_type,
|
||||||
std::vector<const Variable*> parameters,
|
std::vector<const Variable*> parameters,
|
||||||
std::vector<const Variable*> referenced_module_vars,
|
std::vector<const Variable*> referenced_module_vars,
|
||||||
std::vector<const Variable*> local_referenced_module_vars,
|
std::vector<const Variable*> local_referenced_module_vars,
|
||||||
std::vector<const ast::ReturnStatement*> return_statements,
|
std::vector<const ast::ReturnStatement*> return_statements,
|
||||||
std::vector<Symbol> ancestor_entry_points)
|
std::vector<Symbol> ancestor_entry_points)
|
||||||
: Base(declaration->return_type(), GetParameters(declaration)),
|
: Base(return_type, GetParameters(declaration)),
|
||||||
declaration_(declaration),
|
declaration_(declaration),
|
||||||
parameters_(std::move(parameters)),
|
parameters_(std::move(parameters)),
|
||||||
referenced_module_vars_(std::move(referenced_module_vars)),
|
referenced_module_vars_(std::move(referenced_module_vars)),
|
||||||
|
@ -138,8 +139,7 @@ Function::VariableBindings Function::ReferencedStorageTextureVariables() const {
|
||||||
VariableBindings ret;
|
VariableBindings ret;
|
||||||
|
|
||||||
for (auto* var : ReferencedModuleVariables()) {
|
for (auto* var : ReferencedModuleVariables()) {
|
||||||
auto* unwrapped_type =
|
auto* unwrapped_type = var->Type()->UnwrapIfNeeded();
|
||||||
var->Declaration()->declared_type()->UnwrapIfNeeded();
|
|
||||||
auto* storage_texture = unwrapped_type->As<sem::StorageTexture>();
|
auto* storage_texture = unwrapped_type->As<sem::StorageTexture>();
|
||||||
if (storage_texture == nullptr) {
|
if (storage_texture == nullptr) {
|
||||||
continue;
|
continue;
|
||||||
|
@ -156,8 +156,7 @@ Function::VariableBindings Function::ReferencedDepthTextureVariables() const {
|
||||||
VariableBindings ret;
|
VariableBindings ret;
|
||||||
|
|
||||||
for (auto* var : ReferencedModuleVariables()) {
|
for (auto* var : ReferencedModuleVariables()) {
|
||||||
auto* unwrapped_type =
|
auto* unwrapped_type = var->Type()->UnwrapIfNeeded();
|
||||||
var->Declaration()->declared_type()->UnwrapIfNeeded();
|
|
||||||
auto* storage_texture = unwrapped_type->As<sem::DepthTexture>();
|
auto* storage_texture = unwrapped_type->As<sem::DepthTexture>();
|
||||||
if (storage_texture == nullptr) {
|
if (storage_texture == nullptr) {
|
||||||
continue;
|
continue;
|
||||||
|
@ -184,8 +183,7 @@ Function::VariableBindings Function::ReferencedSamplerVariablesImpl(
|
||||||
VariableBindings ret;
|
VariableBindings ret;
|
||||||
|
|
||||||
for (auto* var : ReferencedModuleVariables()) {
|
for (auto* var : ReferencedModuleVariables()) {
|
||||||
auto* unwrapped_type =
|
auto* unwrapped_type = var->Type()->UnwrapIfNeeded();
|
||||||
var->Declaration()->declared_type()->UnwrapIfNeeded();
|
|
||||||
auto* sampler = unwrapped_type->As<sem::Sampler>();
|
auto* sampler = unwrapped_type->As<sem::Sampler>();
|
||||||
if (sampler == nullptr || sampler->kind() != kind) {
|
if (sampler == nullptr || sampler->kind() != kind) {
|
||||||
continue;
|
continue;
|
||||||
|
@ -203,8 +201,7 @@ Function::VariableBindings Function::ReferencedSampledTextureVariablesImpl(
|
||||||
VariableBindings ret;
|
VariableBindings ret;
|
||||||
|
|
||||||
for (auto* var : ReferencedModuleVariables()) {
|
for (auto* var : ReferencedModuleVariables()) {
|
||||||
auto* unwrapped_type =
|
auto* unwrapped_type = var->Type()->UnwrapIfNeeded();
|
||||||
var->Declaration()->declared_type()->UnwrapIfNeeded();
|
|
||||||
auto* texture = unwrapped_type->As<sem::Texture>();
|
auto* texture = unwrapped_type->As<sem::Texture>();
|
||||||
if (texture == nullptr) {
|
if (texture == nullptr) {
|
||||||
continue;
|
continue;
|
||||||
|
|
|
@ -46,6 +46,7 @@ class Function : public Castable<Function, CallTarget> {
|
||||||
|
|
||||||
/// Constructor
|
/// Constructor
|
||||||
/// @param declaration the ast::Function
|
/// @param declaration the ast::Function
|
||||||
|
/// @param return_type the return type of the function
|
||||||
/// @param parameters the parameters to the function
|
/// @param parameters the parameters to the function
|
||||||
/// @param referenced_module_vars the referenced module variables
|
/// @param referenced_module_vars the referenced module variables
|
||||||
/// @param local_referenced_module_vars the locally referenced module
|
/// @param local_referenced_module_vars the locally referenced module
|
||||||
|
@ -53,6 +54,7 @@ class Function : public Castable<Function, CallTarget> {
|
||||||
/// variables
|
/// variables
|
||||||
/// @param ancestor_entry_points the ancestor entry points
|
/// @param ancestor_entry_points the ancestor entry points
|
||||||
Function(ast::Function* declaration,
|
Function(ast::Function* declaration,
|
||||||
|
Type* return_type,
|
||||||
std::vector<const Variable*> parameters,
|
std::vector<const Variable*> parameters,
|
||||||
std::vector<const Variable*> referenced_module_vars,
|
std::vector<const Variable*> referenced_module_vars,
|
||||||
std::vector<const Variable*> local_referenced_module_vars,
|
std::vector<const Variable*> local_referenced_module_vars,
|
||||||
|
|
Loading…
Reference in New Issue