Set function storage class in type determiner.

If a non-const variable in a function has a kNone storage class we
update it to kFunction. If there is a storage class other then kFunction
we emit an error.

Bug: tint:5
Change-Id: If45eb91bd0a0095e625eb1d0e1d1e361c784e35d
Reviewed-on: https://dawn-review.googlesource.com/c/tint/+/19102
Reviewed-by: David Neto <dneto@google.com>
This commit is contained in:
dan sinclair 2020-04-08 19:58:20 +00:00 committed by dan sinclair
parent 9459dbf3ab
commit ee8ae04472
9 changed files with 150 additions and 41 deletions

View File

@ -14,6 +14,8 @@
#include "src/reader/spirv/function.h"
#include <utility>
#include "source/opt/basic_block.h"
#include "source/opt/function.h"
#include "source/opt/instruction.h"
@ -154,7 +156,8 @@ bool FunctionEmitter::EmitFunctionVariables() {
parser_impl_.MakeConstantExpression(inst.GetSingleWordInOperand(1)));
}
// TODO(dneto): Add the initializer via Variable::set_constructor.
auto var_decl_stmt = std::make_unique<ast::VariableDeclStatement>(std::move(var));
auto var_decl_stmt =
std::make_unique<ast::VariableDeclStatement>(std::move(var));
ast_body_.emplace_back(std::move(var_decl_stmt));
}
return success();

View File

@ -15,6 +15,9 @@
#ifndef SRC_READER_SPIRV_FUNCTION_H_
#define SRC_READER_SPIRV_FUNCTION_H_
#include <memory>
#include <vector>
#include "source/opt/constants.h"
#include "source/opt/function.h"
#include "source/opt/ir_context.h"
@ -50,7 +53,9 @@ class FunctionEmitter {
bool failed() const { return !success(); }
/// @returns the body of the function.
const std::vector<std::unique_ptr<ast::Statement>>& ast_body() { return ast_body_; }
const std::vector<std::unique_ptr<ast::Statement>>& ast_body() {
return ast_body_;
}
/// Records failure.
/// @returns a FailStream on which to emit diagnostics.

View File

@ -12,12 +12,11 @@
// See the License for the specific language governing permissions and
// limitations under the License.
#include "src/reader/spirv/function.h"
#include <string>
#include <vector>
#include "gmock/gmock.h"
#include "src/reader/spirv/function.h"
#include "src/reader/spirv/parser_impl.h"
#include "src/reader/spirv/parser_impl_test_helper.h"
#include "src/reader/spirv/spirv_tools_helpers_test.h"

View File

@ -800,11 +800,11 @@ std::unique_ptr<ast::Expression> ParserImpl::MakeConstantExpression(
std::make_unique<ast::FloatLiteral>(ast_type, spirv_const->GetFloat()));
}
if (ast_type->IsBool()) {
const bool value = spirv_const->AsNullConstant() ? false :
spirv_const->AsBoolConstant()->value();
const bool value = spirv_const->AsNullConstant()
? false
: spirv_const->AsBoolConstant()->value();
return std::make_unique<ast::ScalarConstructorExpression>(
std::make_unique<ast::BoolLiteral>(
ast_type, value));
std::make_unique<ast::BoolLiteral>(ast_type, value));
}
auto spirv_composite_const = spirv_const->AsCompositeConstant();
if (spirv_composite_const != nullptr) {

View File

@ -16,8 +16,8 @@
#define SRC_READER_SPIRV_PARSER_IMPL_TEST_HELPER_H_
#include <memory>
#include <string>
#include <sstream>
#include <string>
#include <vector>
#include "gtest/gtest.h"
@ -66,7 +66,8 @@ class SpvParserTest : public testing::Test {
/// Returns the string dump of a function body.
/// @param body the statement in the body
/// @returnss the string dump of a function body.
inline std::string ToString(const std::vector<std::unique_ptr<ast::Statement>>& body) {
inline std::string ToString(
const std::vector<std::unique_ptr<ast::Statement>>& body) {
std::ostringstream outs;
for (const auto& stmt : body) {
stmt->to_str(outs, 0);

View File

@ -82,7 +82,7 @@ bool TypeDeterminer::DetermineFunctions(const ast::FunctionList& funcs) {
bool TypeDeterminer::DetermineFunction(ast::Function* func) {
variable_stack_.push_scope();
if (!DetermineResultType(func->body())) {
if (!DetermineStatements(func->body())) {
return false;
}
variable_stack_.pop_scope();
@ -90,8 +90,12 @@ bool TypeDeterminer::DetermineFunction(ast::Function* func) {
return true;
}
bool TypeDeterminer::DetermineResultType(const ast::StatementList& stmts) {
bool TypeDeterminer::DetermineStatements(const ast::StatementList& stmts) {
for (const auto& stmt : stmts) {
if (!DetermineVariableStorageClass(stmt.get())) {
return false;
}
if (!DetermineResultType(stmt.get())) {
return false;
}
@ -99,6 +103,30 @@ bool TypeDeterminer::DetermineResultType(const ast::StatementList& stmts) {
return true;
}
bool TypeDeterminer::DetermineVariableStorageClass(ast::Statement* stmt) {
if (!stmt->IsVariableDecl()) {
return true;
}
auto var = stmt->AsVariableDecl()->variable();
// Nothing to do for const
if (var->is_const()) {
return true;
}
if (var->storage_class() == ast::StorageClass::kFunction) {
return true;
}
if (var->storage_class() != ast::StorageClass::kNone) {
error_ = "function variable has a non-function storage class";
return false;
}
var->set_storage_class(ast::StorageClass::kFunction);
return true;
}
bool TypeDeterminer::DetermineResultType(ast::Statement* stmt) {
if (stmt->IsAssign()) {
auto a = stmt->AsAssign();
@ -110,7 +138,7 @@ bool TypeDeterminer::DetermineResultType(ast::Statement* stmt) {
}
if (stmt->IsCase()) {
auto c = stmt->AsCase();
return DetermineResultType(c->body());
return DetermineStatements(c->body());
}
if (stmt->IsContinue()) {
auto c = stmt->AsContinue();
@ -119,7 +147,7 @@ bool TypeDeterminer::DetermineResultType(ast::Statement* stmt) {
if (stmt->IsElse()) {
auto e = stmt->AsElse();
return DetermineResultType(e->condition()) &&
DetermineResultType(e->body());
DetermineStatements(e->body());
}
if (stmt->IsFallthrough()) {
return true;
@ -127,7 +155,7 @@ bool TypeDeterminer::DetermineResultType(ast::Statement* stmt) {
if (stmt->IsIf()) {
auto i = stmt->AsIf();
if (!DetermineResultType(i->condition()) ||
!DetermineResultType(i->body())) {
!DetermineStatements(i->body())) {
return false;
}
@ -143,8 +171,8 @@ bool TypeDeterminer::DetermineResultType(ast::Statement* stmt) {
}
if (stmt->IsLoop()) {
auto l = stmt->AsLoop();
return DetermineResultType(l->body()) &&
DetermineResultType(l->continuing());
return DetermineStatements(l->body()) &&
DetermineStatements(l->continuing());
}
if (stmt->IsNop()) {
return true;
@ -152,7 +180,7 @@ bool TypeDeterminer::DetermineResultType(ast::Statement* stmt) {
if (stmt->IsRegardless()) {
auto r = stmt->AsRegardless();
return DetermineResultType(r->condition()) &&
DetermineResultType(r->body());
DetermineStatements(r->body());
}
if (stmt->IsReturn()) {
auto r = stmt->AsReturn();
@ -173,7 +201,7 @@ bool TypeDeterminer::DetermineResultType(ast::Statement* stmt) {
if (stmt->IsUnless()) {
auto u = stmt->AsUnless();
return DetermineResultType(u->condition()) &&
DetermineResultType(u->body());
DetermineStatements(u->body());
}
if (stmt->IsVariableDecl()) {
auto v = stmt->AsVariableDecl();

View File

@ -67,7 +67,7 @@ class TypeDeterminer {
/// Determines type information for a set of statements
/// @param stmts the statements to check
/// @returns true if the determination was successful
bool DetermineResultType(const ast::StatementList& stmts);
bool DetermineStatements(const ast::StatementList& stmts);
/// Determines type information for a statement
/// @param stmt the statement to check
/// @returns true if the determination was successful
@ -76,6 +76,11 @@ class TypeDeterminer {
/// @param expr the expression to check
/// @returns true if the determination was successful
bool DetermineResultType(ast::Expression* expr);
/// Determines the storage class for variables. This assumes that it is only
/// called for things in function scope, not module scope.
/// @param stmt the statement to check
/// @returns false on error
bool DetermineVariableStorageClass(ast::Statement* stmt);
private:
bool DetermineArrayAccessor(ast::ArrayAccessorExpression* expr);

View File

@ -1480,5 +1480,69 @@ INSTANTIATE_TEST_SUITE_P(TypeDeterminerTest,
testing::Values(ast::UnaryOp::kNegation,
ast::UnaryOp::kNot));
TEST_F(TypeDeterminerTest, StorageClass_SetsIfMissing) {
ast::type::I32Type i32;
auto var =
std::make_unique<ast::Variable>("var", ast::StorageClass::kNone, &i32);
auto var_ptr = var.get();
auto stmt = std::make_unique<ast::VariableDeclStatement>(std::move(var));
auto func =
std::make_unique<ast::Function>("func", ast::VariableList{}, &i32);
ast::StatementList stmts;
stmts.push_back(std::move(stmt));
func->set_body(std::move(stmts));
ast::Module m;
m.AddFunction(std::move(func));
EXPECT_TRUE(td()->Determine(&m)) << td()->error();
EXPECT_EQ(var_ptr->storage_class(), ast::StorageClass::kFunction);
}
TEST_F(TypeDeterminerTest, StorageClass_DoesNotSetOnConst) {
ast::type::I32Type i32;
auto var =
std::make_unique<ast::Variable>("var", ast::StorageClass::kNone, &i32);
var->set_is_const(true);
auto var_ptr = var.get();
auto stmt = std::make_unique<ast::VariableDeclStatement>(std::move(var));
auto func =
std::make_unique<ast::Function>("func", ast::VariableList{}, &i32);
ast::StatementList stmts;
stmts.push_back(std::move(stmt));
func->set_body(std::move(stmts));
ast::Module m;
m.AddFunction(std::move(func));
EXPECT_TRUE(td()->Determine(&m)) << td()->error();
EXPECT_EQ(var_ptr->storage_class(), ast::StorageClass::kNone);
}
TEST_F(TypeDeterminerTest, StorageClass_NonFunctionClassError) {
ast::type::I32Type i32;
auto var = std::make_unique<ast::Variable>(
"var", ast::StorageClass::kWorkgroup, &i32);
auto stmt = std::make_unique<ast::VariableDeclStatement>(std::move(var));
auto func =
std::make_unique<ast::Function>("func", ast::VariableList{}, &i32);
ast::StatementList stmts;
stmts.push_back(std::move(stmt));
func->set_body(std::move(stmts));
ast::Module m;
m.AddFunction(std::move(func));
EXPECT_FALSE(td()->Determine(&m));
EXPECT_EQ(td()->error(),
"function variable has a non-function storage class");
}
} // namespace
} // namespace tint

View File

@ -16,13 +16,13 @@
#include "gtest/gtest.h"
#include "src/ast/binary_expression.h"
#include "src/ast/type_constructor_expression.h"
#include "src/ast/float_literal.h"
#include "src/ast/int_literal.h"
#include "src/ast/scalar_constructor_expression.h"
#include "src/ast/type/f32_type.h"
#include "src/ast/type/i32_type.h"
#include "src/ast/type/vector_type.h"
#include "src/ast/type_constructor_expression.h"
#include "src/context.h"
#include "src/type_determiner.h"
#include "src/writer/spirv/builder.h"
@ -74,7 +74,8 @@ TEST_F(BuilderTest, Binary_Add_Integer_Vectors) {
std::make_unique<ast::IntLiteral>(&i32, 1)));
vals.push_back(std::make_unique<ast::ScalarConstructorExpression>(
std::make_unique<ast::IntLiteral>(&i32, 1)));
auto lhs = std::make_unique<ast::TypeConstructorExpression>(&vec3, std::move(vals));
auto lhs =
std::make_unique<ast::TypeConstructorExpression>(&vec3, std::move(vals));
vals.push_back(std::make_unique<ast::ScalarConstructorExpression>(
std::make_unique<ast::IntLiteral>(&i32, 1)));
@ -82,13 +83,14 @@ TEST_F(BuilderTest, Binary_Add_Integer_Vectors) {
std::make_unique<ast::IntLiteral>(&i32, 1)));
vals.push_back(std::make_unique<ast::ScalarConstructorExpression>(
std::make_unique<ast::IntLiteral>(&i32, 1)));
auto rhs = std::make_unique<ast::TypeConstructorExpression>(&vec3, std::move(vals));
auto rhs =
std::make_unique<ast::TypeConstructorExpression>(&vec3, std::move(vals));
Context ctx;
TypeDeterminer td(&ctx);
ast::BinaryExpression expr(
ast::BinaryOp::kAdd, std::move(lhs), std::move(rhs));
ast::BinaryExpression expr(ast::BinaryOp::kAdd, std::move(lhs),
std::move(rhs));
ASSERT_TRUE(td.DetermineResultType(&expr)) << td.error();
@ -145,7 +147,8 @@ TEST_F(BuilderTest, Binary_Add_Float_Vectors) {
std::make_unique<ast::FloatLiteral>(&f32, 1.f)));
vals.push_back(std::make_unique<ast::ScalarConstructorExpression>(
std::make_unique<ast::FloatLiteral>(&f32, 1.f)));
auto lhs = std::make_unique<ast::TypeConstructorExpression>(&vec3, std::move(vals));
auto lhs =
std::make_unique<ast::TypeConstructorExpression>(&vec3, std::move(vals));
vals.push_back(std::make_unique<ast::ScalarConstructorExpression>(
std::make_unique<ast::FloatLiteral>(&f32, 1.f)));
@ -153,13 +156,14 @@ TEST_F(BuilderTest, Binary_Add_Float_Vectors) {
std::make_unique<ast::FloatLiteral>(&f32, 1.f)));
vals.push_back(std::make_unique<ast::ScalarConstructorExpression>(
std::make_unique<ast::FloatLiteral>(&f32, 1.f)));
auto rhs = std::make_unique<ast::TypeConstructorExpression>(&vec3, std::move(vals));
auto rhs =
std::make_unique<ast::TypeConstructorExpression>(&vec3, std::move(vals));
Context ctx;
TypeDeterminer td(&ctx);
ast::BinaryExpression expr(
ast::BinaryOp::kAdd, std::move(lhs), std::move(rhs));
ast::BinaryExpression expr(ast::BinaryOp::kAdd, std::move(lhs),
std::move(rhs));
ASSERT_TRUE(td.DetermineResultType(&expr)) << td.error();