Set function storage class in type determiner.
If a non-const variable in a function has a kNone storage class we update it to kFunction. If there is a storage class other then kFunction we emit an error. Bug: tint:5 Change-Id: If45eb91bd0a0095e625eb1d0e1d1e361c784e35d Reviewed-on: https://dawn-review.googlesource.com/c/tint/+/19102 Reviewed-by: David Neto <dneto@google.com>
This commit is contained in:
parent
9459dbf3ab
commit
ee8ae04472
|
@ -14,6 +14,8 @@
|
|||
|
||||
#include "src/reader/spirv/function.h"
|
||||
|
||||
#include <utility>
|
||||
|
||||
#include "source/opt/basic_block.h"
|
||||
#include "source/opt/function.h"
|
||||
#include "source/opt/instruction.h"
|
||||
|
@ -154,7 +156,8 @@ bool FunctionEmitter::EmitFunctionVariables() {
|
|||
parser_impl_.MakeConstantExpression(inst.GetSingleWordInOperand(1)));
|
||||
}
|
||||
// TODO(dneto): Add the initializer via Variable::set_constructor.
|
||||
auto var_decl_stmt = std::make_unique<ast::VariableDeclStatement>(std::move(var));
|
||||
auto var_decl_stmt =
|
||||
std::make_unique<ast::VariableDeclStatement>(std::move(var));
|
||||
ast_body_.emplace_back(std::move(var_decl_stmt));
|
||||
}
|
||||
return success();
|
||||
|
|
|
@ -15,6 +15,9 @@
|
|||
#ifndef SRC_READER_SPIRV_FUNCTION_H_
|
||||
#define SRC_READER_SPIRV_FUNCTION_H_
|
||||
|
||||
#include <memory>
|
||||
#include <vector>
|
||||
|
||||
#include "source/opt/constants.h"
|
||||
#include "source/opt/function.h"
|
||||
#include "source/opt/ir_context.h"
|
||||
|
@ -50,7 +53,9 @@ class FunctionEmitter {
|
|||
bool failed() const { return !success(); }
|
||||
|
||||
/// @returns the body of the function.
|
||||
const std::vector<std::unique_ptr<ast::Statement>>& ast_body() { return ast_body_; }
|
||||
const std::vector<std::unique_ptr<ast::Statement>>& ast_body() {
|
||||
return ast_body_;
|
||||
}
|
||||
|
||||
/// Records failure.
|
||||
/// @returns a FailStream on which to emit diagnostics.
|
||||
|
|
|
@ -12,12 +12,11 @@
|
|||
// See the License for the specific language governing permissions and
|
||||
// limitations under the License.
|
||||
|
||||
#include "src/reader/spirv/function.h"
|
||||
|
||||
#include <string>
|
||||
#include <vector>
|
||||
|
||||
#include "gmock/gmock.h"
|
||||
#include "src/reader/spirv/function.h"
|
||||
#include "src/reader/spirv/parser_impl.h"
|
||||
#include "src/reader/spirv/parser_impl_test_helper.h"
|
||||
#include "src/reader/spirv/spirv_tools_helpers_test.h"
|
||||
|
|
|
@ -800,11 +800,11 @@ std::unique_ptr<ast::Expression> ParserImpl::MakeConstantExpression(
|
|||
std::make_unique<ast::FloatLiteral>(ast_type, spirv_const->GetFloat()));
|
||||
}
|
||||
if (ast_type->IsBool()) {
|
||||
const bool value = spirv_const->AsNullConstant() ? false :
|
||||
spirv_const->AsBoolConstant()->value();
|
||||
const bool value = spirv_const->AsNullConstant()
|
||||
? false
|
||||
: spirv_const->AsBoolConstant()->value();
|
||||
return std::make_unique<ast::ScalarConstructorExpression>(
|
||||
std::make_unique<ast::BoolLiteral>(
|
||||
ast_type, value));
|
||||
std::make_unique<ast::BoolLiteral>(ast_type, value));
|
||||
}
|
||||
auto spirv_composite_const = spirv_const->AsCompositeConstant();
|
||||
if (spirv_composite_const != nullptr) {
|
||||
|
|
|
@ -16,8 +16,8 @@
|
|||
#define SRC_READER_SPIRV_PARSER_IMPL_TEST_HELPER_H_
|
||||
|
||||
#include <memory>
|
||||
#include <string>
|
||||
#include <sstream>
|
||||
#include <string>
|
||||
#include <vector>
|
||||
|
||||
#include "gtest/gtest.h"
|
||||
|
@ -66,7 +66,8 @@ class SpvParserTest : public testing::Test {
|
|||
/// Returns the string dump of a function body.
|
||||
/// @param body the statement in the body
|
||||
/// @returnss the string dump of a function body.
|
||||
inline std::string ToString(const std::vector<std::unique_ptr<ast::Statement>>& body) {
|
||||
inline std::string ToString(
|
||||
const std::vector<std::unique_ptr<ast::Statement>>& body) {
|
||||
std::ostringstream outs;
|
||||
for (const auto& stmt : body) {
|
||||
stmt->to_str(outs, 0);
|
||||
|
|
|
@ -82,7 +82,7 @@ bool TypeDeterminer::DetermineFunctions(const ast::FunctionList& funcs) {
|
|||
|
||||
bool TypeDeterminer::DetermineFunction(ast::Function* func) {
|
||||
variable_stack_.push_scope();
|
||||
if (!DetermineResultType(func->body())) {
|
||||
if (!DetermineStatements(func->body())) {
|
||||
return false;
|
||||
}
|
||||
variable_stack_.pop_scope();
|
||||
|
@ -90,8 +90,12 @@ bool TypeDeterminer::DetermineFunction(ast::Function* func) {
|
|||
return true;
|
||||
}
|
||||
|
||||
bool TypeDeterminer::DetermineResultType(const ast::StatementList& stmts) {
|
||||
bool TypeDeterminer::DetermineStatements(const ast::StatementList& stmts) {
|
||||
for (const auto& stmt : stmts) {
|
||||
if (!DetermineVariableStorageClass(stmt.get())) {
|
||||
return false;
|
||||
}
|
||||
|
||||
if (!DetermineResultType(stmt.get())) {
|
||||
return false;
|
||||
}
|
||||
|
@ -99,6 +103,30 @@ bool TypeDeterminer::DetermineResultType(const ast::StatementList& stmts) {
|
|||
return true;
|
||||
}
|
||||
|
||||
bool TypeDeterminer::DetermineVariableStorageClass(ast::Statement* stmt) {
|
||||
if (!stmt->IsVariableDecl()) {
|
||||
return true;
|
||||
}
|
||||
|
||||
auto var = stmt->AsVariableDecl()->variable();
|
||||
// Nothing to do for const
|
||||
if (var->is_const()) {
|
||||
return true;
|
||||
}
|
||||
|
||||
if (var->storage_class() == ast::StorageClass::kFunction) {
|
||||
return true;
|
||||
}
|
||||
|
||||
if (var->storage_class() != ast::StorageClass::kNone) {
|
||||
error_ = "function variable has a non-function storage class";
|
||||
return false;
|
||||
}
|
||||
|
||||
var->set_storage_class(ast::StorageClass::kFunction);
|
||||
return true;
|
||||
}
|
||||
|
||||
bool TypeDeterminer::DetermineResultType(ast::Statement* stmt) {
|
||||
if (stmt->IsAssign()) {
|
||||
auto a = stmt->AsAssign();
|
||||
|
@ -110,7 +138,7 @@ bool TypeDeterminer::DetermineResultType(ast::Statement* stmt) {
|
|||
}
|
||||
if (stmt->IsCase()) {
|
||||
auto c = stmt->AsCase();
|
||||
return DetermineResultType(c->body());
|
||||
return DetermineStatements(c->body());
|
||||
}
|
||||
if (stmt->IsContinue()) {
|
||||
auto c = stmt->AsContinue();
|
||||
|
@ -119,7 +147,7 @@ bool TypeDeterminer::DetermineResultType(ast::Statement* stmt) {
|
|||
if (stmt->IsElse()) {
|
||||
auto e = stmt->AsElse();
|
||||
return DetermineResultType(e->condition()) &&
|
||||
DetermineResultType(e->body());
|
||||
DetermineStatements(e->body());
|
||||
}
|
||||
if (stmt->IsFallthrough()) {
|
||||
return true;
|
||||
|
@ -127,7 +155,7 @@ bool TypeDeterminer::DetermineResultType(ast::Statement* stmt) {
|
|||
if (stmt->IsIf()) {
|
||||
auto i = stmt->AsIf();
|
||||
if (!DetermineResultType(i->condition()) ||
|
||||
!DetermineResultType(i->body())) {
|
||||
!DetermineStatements(i->body())) {
|
||||
return false;
|
||||
}
|
||||
|
||||
|
@ -143,8 +171,8 @@ bool TypeDeterminer::DetermineResultType(ast::Statement* stmt) {
|
|||
}
|
||||
if (stmt->IsLoop()) {
|
||||
auto l = stmt->AsLoop();
|
||||
return DetermineResultType(l->body()) &&
|
||||
DetermineResultType(l->continuing());
|
||||
return DetermineStatements(l->body()) &&
|
||||
DetermineStatements(l->continuing());
|
||||
}
|
||||
if (stmt->IsNop()) {
|
||||
return true;
|
||||
|
@ -152,7 +180,7 @@ bool TypeDeterminer::DetermineResultType(ast::Statement* stmt) {
|
|||
if (stmt->IsRegardless()) {
|
||||
auto r = stmt->AsRegardless();
|
||||
return DetermineResultType(r->condition()) &&
|
||||
DetermineResultType(r->body());
|
||||
DetermineStatements(r->body());
|
||||
}
|
||||
if (stmt->IsReturn()) {
|
||||
auto r = stmt->AsReturn();
|
||||
|
@ -173,7 +201,7 @@ bool TypeDeterminer::DetermineResultType(ast::Statement* stmt) {
|
|||
if (stmt->IsUnless()) {
|
||||
auto u = stmt->AsUnless();
|
||||
return DetermineResultType(u->condition()) &&
|
||||
DetermineResultType(u->body());
|
||||
DetermineStatements(u->body());
|
||||
}
|
||||
if (stmt->IsVariableDecl()) {
|
||||
auto v = stmt->AsVariableDecl();
|
||||
|
|
|
@ -67,7 +67,7 @@ class TypeDeterminer {
|
|||
/// Determines type information for a set of statements
|
||||
/// @param stmts the statements to check
|
||||
/// @returns true if the determination was successful
|
||||
bool DetermineResultType(const ast::StatementList& stmts);
|
||||
bool DetermineStatements(const ast::StatementList& stmts);
|
||||
/// Determines type information for a statement
|
||||
/// @param stmt the statement to check
|
||||
/// @returns true if the determination was successful
|
||||
|
@ -76,6 +76,11 @@ class TypeDeterminer {
|
|||
/// @param expr the expression to check
|
||||
/// @returns true if the determination was successful
|
||||
bool DetermineResultType(ast::Expression* expr);
|
||||
/// Determines the storage class for variables. This assumes that it is only
|
||||
/// called for things in function scope, not module scope.
|
||||
/// @param stmt the statement to check
|
||||
/// @returns false on error
|
||||
bool DetermineVariableStorageClass(ast::Statement* stmt);
|
||||
|
||||
private:
|
||||
bool DetermineArrayAccessor(ast::ArrayAccessorExpression* expr);
|
||||
|
|
|
@ -1480,5 +1480,69 @@ INSTANTIATE_TEST_SUITE_P(TypeDeterminerTest,
|
|||
testing::Values(ast::UnaryOp::kNegation,
|
||||
ast::UnaryOp::kNot));
|
||||
|
||||
TEST_F(TypeDeterminerTest, StorageClass_SetsIfMissing) {
|
||||
ast::type::I32Type i32;
|
||||
|
||||
auto var =
|
||||
std::make_unique<ast::Variable>("var", ast::StorageClass::kNone, &i32);
|
||||
auto var_ptr = var.get();
|
||||
auto stmt = std::make_unique<ast::VariableDeclStatement>(std::move(var));
|
||||
|
||||
auto func =
|
||||
std::make_unique<ast::Function>("func", ast::VariableList{}, &i32);
|
||||
ast::StatementList stmts;
|
||||
stmts.push_back(std::move(stmt));
|
||||
func->set_body(std::move(stmts));
|
||||
|
||||
ast::Module m;
|
||||
m.AddFunction(std::move(func));
|
||||
|
||||
EXPECT_TRUE(td()->Determine(&m)) << td()->error();
|
||||
EXPECT_EQ(var_ptr->storage_class(), ast::StorageClass::kFunction);
|
||||
}
|
||||
|
||||
TEST_F(TypeDeterminerTest, StorageClass_DoesNotSetOnConst) {
|
||||
ast::type::I32Type i32;
|
||||
|
||||
auto var =
|
||||
std::make_unique<ast::Variable>("var", ast::StorageClass::kNone, &i32);
|
||||
var->set_is_const(true);
|
||||
auto var_ptr = var.get();
|
||||
auto stmt = std::make_unique<ast::VariableDeclStatement>(std::move(var));
|
||||
|
||||
auto func =
|
||||
std::make_unique<ast::Function>("func", ast::VariableList{}, &i32);
|
||||
ast::StatementList stmts;
|
||||
stmts.push_back(std::move(stmt));
|
||||
func->set_body(std::move(stmts));
|
||||
|
||||
ast::Module m;
|
||||
m.AddFunction(std::move(func));
|
||||
|
||||
EXPECT_TRUE(td()->Determine(&m)) << td()->error();
|
||||
EXPECT_EQ(var_ptr->storage_class(), ast::StorageClass::kNone);
|
||||
}
|
||||
|
||||
TEST_F(TypeDeterminerTest, StorageClass_NonFunctionClassError) {
|
||||
ast::type::I32Type i32;
|
||||
|
||||
auto var = std::make_unique<ast::Variable>(
|
||||
"var", ast::StorageClass::kWorkgroup, &i32);
|
||||
auto stmt = std::make_unique<ast::VariableDeclStatement>(std::move(var));
|
||||
|
||||
auto func =
|
||||
std::make_unique<ast::Function>("func", ast::VariableList{}, &i32);
|
||||
ast::StatementList stmts;
|
||||
stmts.push_back(std::move(stmt));
|
||||
func->set_body(std::move(stmts));
|
||||
|
||||
ast::Module m;
|
||||
m.AddFunction(std::move(func));
|
||||
|
||||
EXPECT_FALSE(td()->Determine(&m));
|
||||
EXPECT_EQ(td()->error(),
|
||||
"function variable has a non-function storage class");
|
||||
}
|
||||
|
||||
} // namespace
|
||||
} // namespace tint
|
||||
|
|
|
@ -16,13 +16,13 @@
|
|||
|
||||
#include "gtest/gtest.h"
|
||||
#include "src/ast/binary_expression.h"
|
||||
#include "src/ast/type_constructor_expression.h"
|
||||
#include "src/ast/float_literal.h"
|
||||
#include "src/ast/int_literal.h"
|
||||
#include "src/ast/scalar_constructor_expression.h"
|
||||
#include "src/ast/type/f32_type.h"
|
||||
#include "src/ast/type/i32_type.h"
|
||||
#include "src/ast/type/vector_type.h"
|
||||
#include "src/ast/type_constructor_expression.h"
|
||||
#include "src/context.h"
|
||||
#include "src/type_determiner.h"
|
||||
#include "src/writer/spirv/builder.h"
|
||||
|
@ -69,26 +69,28 @@ TEST_F(BuilderTest, Binary_Add_Integer_Vectors) {
|
|||
|
||||
ast::ExpressionList vals;
|
||||
vals.push_back(std::make_unique<ast::ScalarConstructorExpression>(
|
||||
std::make_unique<ast::IntLiteral>(&i32, 1)));
|
||||
std::make_unique<ast::IntLiteral>(&i32, 1)));
|
||||
vals.push_back(std::make_unique<ast::ScalarConstructorExpression>(
|
||||
std::make_unique<ast::IntLiteral>(&i32, 1)));
|
||||
std::make_unique<ast::IntLiteral>(&i32, 1)));
|
||||
vals.push_back(std::make_unique<ast::ScalarConstructorExpression>(
|
||||
std::make_unique<ast::IntLiteral>(&i32, 1)));
|
||||
auto lhs = std::make_unique<ast::TypeConstructorExpression>(&vec3, std::move(vals));
|
||||
std::make_unique<ast::IntLiteral>(&i32, 1)));
|
||||
auto lhs =
|
||||
std::make_unique<ast::TypeConstructorExpression>(&vec3, std::move(vals));
|
||||
|
||||
vals.push_back(std::make_unique<ast::ScalarConstructorExpression>(
|
||||
std::make_unique<ast::IntLiteral>(&i32, 1)));
|
||||
std::make_unique<ast::IntLiteral>(&i32, 1)));
|
||||
vals.push_back(std::make_unique<ast::ScalarConstructorExpression>(
|
||||
std::make_unique<ast::IntLiteral>(&i32, 1)));
|
||||
std::make_unique<ast::IntLiteral>(&i32, 1)));
|
||||
vals.push_back(std::make_unique<ast::ScalarConstructorExpression>(
|
||||
std::make_unique<ast::IntLiteral>(&i32, 1)));
|
||||
auto rhs = std::make_unique<ast::TypeConstructorExpression>(&vec3, std::move(vals));
|
||||
std::make_unique<ast::IntLiteral>(&i32, 1)));
|
||||
auto rhs =
|
||||
std::make_unique<ast::TypeConstructorExpression>(&vec3, std::move(vals));
|
||||
|
||||
Context ctx;
|
||||
TypeDeterminer td(&ctx);
|
||||
|
||||
ast::BinaryExpression expr(
|
||||
ast::BinaryOp::kAdd, std::move(lhs), std::move(rhs));
|
||||
ast::BinaryExpression expr(ast::BinaryOp::kAdd, std::move(lhs),
|
||||
std::move(rhs));
|
||||
|
||||
ASSERT_TRUE(td.DetermineResultType(&expr)) << td.error();
|
||||
|
||||
|
@ -140,26 +142,28 @@ TEST_F(BuilderTest, Binary_Add_Float_Vectors) {
|
|||
|
||||
ast::ExpressionList vals;
|
||||
vals.push_back(std::make_unique<ast::ScalarConstructorExpression>(
|
||||
std::make_unique<ast::FloatLiteral>(&f32, 1.f)));
|
||||
std::make_unique<ast::FloatLiteral>(&f32, 1.f)));
|
||||
vals.push_back(std::make_unique<ast::ScalarConstructorExpression>(
|
||||
std::make_unique<ast::FloatLiteral>(&f32, 1.f)));
|
||||
std::make_unique<ast::FloatLiteral>(&f32, 1.f)));
|
||||
vals.push_back(std::make_unique<ast::ScalarConstructorExpression>(
|
||||
std::make_unique<ast::FloatLiteral>(&f32, 1.f)));
|
||||
auto lhs = std::make_unique<ast::TypeConstructorExpression>(&vec3, std::move(vals));
|
||||
std::make_unique<ast::FloatLiteral>(&f32, 1.f)));
|
||||
auto lhs =
|
||||
std::make_unique<ast::TypeConstructorExpression>(&vec3, std::move(vals));
|
||||
|
||||
vals.push_back(std::make_unique<ast::ScalarConstructorExpression>(
|
||||
std::make_unique<ast::FloatLiteral>(&f32, 1.f)));
|
||||
std::make_unique<ast::FloatLiteral>(&f32, 1.f)));
|
||||
vals.push_back(std::make_unique<ast::ScalarConstructorExpression>(
|
||||
std::make_unique<ast::FloatLiteral>(&f32, 1.f)));
|
||||
std::make_unique<ast::FloatLiteral>(&f32, 1.f)));
|
||||
vals.push_back(std::make_unique<ast::ScalarConstructorExpression>(
|
||||
std::make_unique<ast::FloatLiteral>(&f32, 1.f)));
|
||||
auto rhs = std::make_unique<ast::TypeConstructorExpression>(&vec3, std::move(vals));
|
||||
std::make_unique<ast::FloatLiteral>(&f32, 1.f)));
|
||||
auto rhs =
|
||||
std::make_unique<ast::TypeConstructorExpression>(&vec3, std::move(vals));
|
||||
|
||||
Context ctx;
|
||||
TypeDeterminer td(&ctx);
|
||||
|
||||
ast::BinaryExpression expr(
|
||||
ast::BinaryOp::kAdd, std::move(lhs), std::move(rhs));
|
||||
ast::BinaryExpression expr(ast::BinaryOp::kAdd, std::move(lhs),
|
||||
std::move(rhs));
|
||||
|
||||
ASSERT_TRUE(td.DetermineResultType(&expr)) << td.error();
|
||||
|
||||
|
|
Loading…
Reference in New Issue