ast: Remove types from ast::Literals

A literal has an implicit type, so there should be no type on the AST node.

This highlighted that the resolver was nto canonicalizing TypeConstructorExpression types, which has been fixed.
This required preservation of the declared type name in order for error messages to contain aliased names.

Bug: tint:724
Change-Id: I21594a3e8a0fb1b73c6c5b46a14b8664b7f28512
Reviewed-on: https://dawn-review.googlesource.com/c/tint/+/49345
Reviewed-by: Antonio Maiorano <amaiorano@google.com>
Commit-Queue: Ben Clayton <bclayton@chromium.org>
This commit is contained in:
Ben Clayton
2021-04-28 13:50:43 +00:00
committed by Commit Bot service account
parent 0bf0fb9b29
commit 109b18f504
34 changed files with 236 additions and 264 deletions

View File

@@ -119,7 +119,7 @@ TEST_F(ResolverControlBlockValidationTest,
ast::CaseStatementList switch_body;
ast::CaseSelectorList csl;
csl.push_back(create<ast::UintLiteral>(ty.u32(), 1));
csl.push_back(create<ast::UintLiteral>(1u));
switch_body.push_back(create<ast::CaseStatement>(
Source{Source::Location{12, 34}}, csl, Block()));
@@ -178,12 +178,12 @@ TEST_F(ResolverControlBlockValidationTest,
ast::CaseStatementList switch_body;
ast::CaseSelectorList csl_1;
csl_1.push_back(create<ast::UintLiteral>(ty.u32(), 0));
csl_1.push_back(create<ast::UintLiteral>(0u));
switch_body.push_back(create<ast::CaseStatement>(csl_1, Block()));
ast::CaseSelectorList csl_2;
csl_2.push_back(create<ast::UintLiteral>(ty.u32(), 2));
csl_2.push_back(create<ast::UintLiteral>(ty.u32(), 2));
csl_2.push_back(create<ast::UintLiteral>(2u));
csl_2.push_back(create<ast::UintLiteral>(2u));
switch_body.push_back(create<ast::CaseStatement>(
Source{Source::Location{12, 34}}, csl_2, Block()));

View File

@@ -350,8 +350,13 @@ Resolver::VariableInfo* Resolver::Variable(
return it->second;
}
auto* ctype = Canonical(type ? type : var->declared_type());
auto* info = variable_infos_.Create(var, ctype);
if (!type) {
type = var->declared_type();
}
auto type_name = type->FriendlyName(builder_->Symbols());
auto* ctype = Canonical(type);
auto* info = variable_infos_.Create(var, ctype, type_name);
variable_to_info_.emplace(var, info);
// Resolve variable's type
@@ -1304,29 +1309,35 @@ bool Resolver::Constructor(ast::ConstructorExpression* expr) {
return false;
}
}
SetType(expr, type_ctor->type());
// Now that the argument types have been determined, make sure that they
// obey the constructor type rules laid out in
// https://gpuweb.github.io/gpuweb/wgsl.html#type-constructor-expr.
if (auto* vec_type = type_ctor->type()->As<sem::Vector>()) {
return ValidateVectorConstructor(vec_type, type_ctor->values());
return ValidateVectorConstructor(type_ctor, vec_type,
type_ctor->values());
}
if (auto* mat_type = type_ctor->type()->As<sem::Matrix>()) {
return ValidateMatrixConstructor(mat_type, type_ctor->values());
auto mat_typename = TypeNameOf(type_ctor);
return ValidateMatrixConstructor(type_ctor, mat_type,
type_ctor->values());
}
// TODO(crbug.com/tint/634): Validate array constructor
} else if (auto* scalar_ctor = expr->As<ast::ScalarConstructorExpression>()) {
Mark(scalar_ctor->literal());
SetType(expr, scalar_ctor->literal()->type());
SetType(expr, TypeOf(scalar_ctor->literal()));
} else {
TINT_ICE(diagnostics_) << "unexpected constructor expression type";
}
return true;
}
bool Resolver::ValidateVectorConstructor(const sem::Vector* vec_type,
const ast::ExpressionList& values) {
bool Resolver::ValidateVectorConstructor(
const ast::TypeConstructorExpression* ctor,
const sem::Vector* vec_type,
const ast::ExpressionList& values) {
auto* elem_type = vec_type->type()->UnwrapAll();
size_t value_cardinality_sum = 0;
for (auto* value : values) {
@@ -1337,7 +1348,7 @@ bool Resolver::ValidateVectorConstructor(const sem::Vector* vec_type,
"type in vector constructor does not match vector type: "
"expected '" +
elem_type->FriendlyName(builder_->Symbols()) + "', found '" +
value_type->FriendlyName(builder_->Symbols()) + "'",
TypeNameOf(value) + "'",
value->source());
return false;
}
@@ -1384,8 +1395,7 @@ bool Resolver::ValidateVectorConstructor(const sem::Vector* vec_type,
const Source& values_start = values[0]->source();
const Source& values_end = values[values.size() - 1]->source();
diagnostics_.add_error(
"attempted to construct '" +
vec_type->FriendlyName(builder_->Symbols()) + "' with " +
"attempted to construct '" + TypeNameOf(ctor) + "' with " +
std::to_string(value_cardinality_sum) + " component(s)",
Source::Combine(values_start, values_end));
return false;
@@ -1393,8 +1403,10 @@ bool Resolver::ValidateVectorConstructor(const sem::Vector* vec_type,
return true;
}
bool Resolver::ValidateMatrixConstructor(const sem::Matrix* matrix_type,
const ast::ExpressionList& values) {
bool Resolver::ValidateMatrixConstructor(
const ast::TypeConstructorExpression* ctor,
const sem::Matrix* matrix_type,
const ast::ExpressionList& values) {
// Zero Value expression
if (values.empty()) {
return true;
@@ -1407,8 +1419,8 @@ bool Resolver::ValidateMatrixConstructor(const sem::Matrix* matrix_type,
diagnostics_.add_error(
"expected " + std::to_string(matrix_type->columns()) + " '" +
VectorPretty(matrix_type->rows(), elem_type) + "' arguments in '" +
matrix_type->FriendlyName(builder_->Symbols()) +
"' constructor, found " + std::to_string(values.size()),
TypeNameOf(ctor) + "' constructor, found " +
std::to_string(values.size()),
Source::Combine(values_start, values_end));
return false;
}
@@ -1419,13 +1431,12 @@ bool Resolver::ValidateMatrixConstructor(const sem::Matrix* matrix_type,
if (!value_vec || value_vec->size() != matrix_type->rows() ||
elem_type != value_vec->type()->UnwrapAll()) {
diagnostics_.add_error(
"expected argument type '" +
VectorPretty(matrix_type->rows(), elem_type) + "' in '" +
matrix_type->FriendlyName(builder_->Symbols()) +
"' constructor, found '" +
value_type->FriendlyName(builder_->Symbols()) + "'",
value->source());
diagnostics_.add_error("expected argument type '" +
VectorPretty(matrix_type->rows(), elem_type) +
"' in '" + TypeNameOf(ctor) +
"' constructor, found '" + TypeNameOf(value) +
"'",
value->source());
return false;
}
}
@@ -1440,12 +1451,14 @@ bool Resolver::Identifier(ast::IdentifierExpression* expr) {
// A constant is the type, but a variable is always a pointer so synthesize
// the pointer around the variable type.
if (var->declaration->is_const()) {
SetType(expr, var->type);
SetType(expr, var->type, var->type_name);
} else if (var->type->Is<sem::Pointer>()) {
SetType(expr, var->type);
SetType(expr, var->type, var->type_name);
} else {
SetType(expr, builder_->create<sem::Pointer>(
const_cast<sem::Type*>(var->type), var->storage_class));
SetType(expr,
builder_->create<sem::Pointer>(const_cast<sem::Type*>(var->type),
var->storage_class),
var->type_name);
}
var->users.push_back(expr);
@@ -1962,7 +1975,7 @@ bool Resolver::VariableDeclStatement(const ast::VariableDeclStatement* stmt) {
return true;
}
const sem::Type* Resolver::TypeOf(ast::Expression* expr) {
const sem::Type* Resolver::TypeOf(const ast::Expression* expr) {
auto it = expr_info_.find(expr);
if (it != expr_info_.end()) {
return it->second.type;
@@ -1970,12 +1983,45 @@ const sem::Type* Resolver::TypeOf(ast::Expression* expr) {
return nullptr;
}
std::string Resolver::TypeNameOf(const ast::Expression* expr) {
auto it = expr_info_.find(expr);
if (it != expr_info_.end()) {
return it->second.type_name;
}
return "";
}
const sem::Type* Resolver::TypeOf(const ast::Literal* lit) {
if (lit->Is<ast::SintLiteral>()) {
return builder_->create<sem::I32>();
}
if (lit->Is<ast::UintLiteral>()) {
return builder_->create<sem::U32>();
}
if (lit->Is<ast::FloatLiteral>()) {
return builder_->create<sem::F32>();
}
if (lit->Is<ast::BoolLiteral>()) {
return builder_->create<sem::Bool>();
}
TINT_UNREACHABLE(diagnostics_)
<< "Unhandled literal type: " << lit->TypeInfo().name;
return nullptr;
}
void Resolver::SetType(ast::Expression* expr, const sem::Type* type) {
SetType(expr, type, type->FriendlyName(builder_->Symbols()));
}
void Resolver::SetType(ast::Expression* expr,
const sem::Type* type,
const std::string& type_name) {
if (expr_info_.count(expr)) {
TINT_ICE(builder_->Diagnostics())
<< "SetType() called twice for the same expression";
}
expr_info_.emplace(expr, ExpressionInfo{type, current_statement_});
type = Canonical(type);
expr_info_.emplace(expr, ExpressionInfo{type, type_name, current_statement_});
}
void Resolver::CreateSemanticNodes() const {
@@ -2071,7 +2117,8 @@ void Resolver::CreateSemanticNodes() const {
continue;
}
sem.Add(expr,
builder_->create<sem::Expression>(expr, info.type, info.statement));
builder_->create<sem::Expression>(
const_cast<ast::Expression*>(expr), info.type, info.statement));
}
// Create semantic nodes for all structs
@@ -2488,7 +2535,7 @@ bool Resolver::ValidateSwitch(const ast::SwitchStatement* s) {
}
for (auto* selector : case_stmt->selectors()) {
if (cond_type != selector->type()) {
if (cond_type != TypeOf(selector)) {
diagnostics_.add_error("v-0026",
"the case selector values must have the same "
"type as the selector expression.",
@@ -2729,9 +2776,11 @@ void Resolver::Mark(const ast::Node* node) {
}
Resolver::VariableInfo::VariableInfo(const ast::Variable* decl,
const sem::Type* ctype)
const sem::Type* ctype,
const std::string& tn)
: declaration(decl),
type(ctype),
type_name(tn),
storage_class(decl->declared_storage_class()) {}
Resolver::VariableInfo::~VariableInfo() = default;

View File

@@ -96,11 +96,14 @@ class Resolver {
/// Structure holding semantic information about a variable.
/// Used to build the sem::Variable nodes at the end of resolving.
struct VariableInfo {
VariableInfo(const ast::Variable* decl, const sem::Type* type);
VariableInfo(const ast::Variable* decl,
const sem::Type* type,
const std::string& type_name);
~VariableInfo();
ast::Variable const* const declaration;
sem::Type const* type;
std::string const type_name;
ast::StorageClass storage_class;
std::vector<ast::IdentifierExpression*> users;
};
@@ -125,6 +128,7 @@ class Resolver {
/// Used to build the sem::Expression nodes at the end of resolving.
struct ExpressionInfo {
sem::Type const* type;
std::string const type_name; // Declared type name
sem::Statement* statement;
};
@@ -246,14 +250,16 @@ class Resolver {
bool ValidateEntryPoint(const ast::Function* func, const FunctionInfo* info);
bool ValidateFunction(const ast::Function* func, const FunctionInfo* info);
bool ValidateGlobalVariable(const VariableInfo* var);
bool ValidateMatrixConstructor(const sem::Matrix* matrix_type,
bool ValidateMatrixConstructor(const ast::TypeConstructorExpression* ctor,
const sem::Matrix* matrix_type,
const ast::ExpressionList& values);
bool ValidateParameter(const ast::Variable* param);
bool ValidateReturn(const ast::ReturnStatement* ret);
bool ValidateStructure(const sem::StructType* st);
bool ValidateSwitch(const ast::SwitchStatement* s);
bool ValidateVariable(const ast::Variable* param);
bool ValidateVectorConstructor(const sem::Vector* vec_type,
bool ValidateVectorConstructor(const ast::TypeConstructorExpression* ctor,
const sem::Vector* vec_type,
const ast::ExpressionList& values);
/// @returns the sem::Type for the ast::Type `ty`, building it if it
@@ -303,7 +309,15 @@ class Resolver {
/// @returns the resolved type of the ast::Expression `expr`
/// @param expr the expression
const sem::Type* TypeOf(ast::Expression* expr);
const sem::Type* TypeOf(const ast::Expression* expr);
/// @returns the declared type name of the ast::Expression `expr`
/// @param expr the type name
std::string TypeNameOf(const ast::Expression* expr);
/// @returns the semantic type of the AST literal `lit`
/// @param lit the literal
const sem::Type* TypeOf(const ast::Literal* lit);
/// Creates a sem::Expression node with the resolved type `type`, and
/// assigns this semantic node to the expression `expr`.
@@ -311,6 +325,16 @@ class Resolver {
/// @param type the resolved type
void SetType(ast::Expression* expr, const sem::Type* type);
/// Creates a sem::Expression node with the resolved type `type`, the declared
/// type name `type_name` and assigns this semantic node to the expression
/// `expr`.
/// @param expr the expression
/// @param type the resolved type
/// @param type_name the declared type name
void SetType(ast::Expression* expr,
const sem::Type* type,
const std::string& type_name);
/// Constructs a new BlockInfo with the given type and with #current_block_ as
/// its parent, assigns this to #current_block_, and then calls `callback`.
/// The original #current_block_ is restored on exit.
@@ -340,7 +364,7 @@ class Resolver {
std::unordered_map<const ast::Function*, FunctionInfo*> function_to_info_;
std::unordered_map<const ast::Variable*, VariableInfo*> variable_to_info_;
std::unordered_map<ast::CallExpression*, FunctionCallInfo> function_calls_;
std::unordered_map<ast::Expression*, ExpressionInfo> expr_info_;
std::unordered_map<const ast::Expression*, ExpressionInfo> expr_info_;
std::unordered_map<const sem::StructType*, StructInfo*> struct_info_;
std::unordered_map<const sem::Type*, const sem::Type*> type_to_canonical_;
std::unordered_set<const ast::Node*> marked_;

View File

@@ -81,7 +81,7 @@ TEST_F(ResolverTest, Stmt_Case) {
auto* assign = Assign(lhs, rhs);
auto* block = Block(assign);
ast::CaseSelectorList lit;
lit.push_back(create<ast::SintLiteral>(ty.i32(), 3));
lit.push_back(create<ast::SintLiteral>(3));
auto* cse = create<ast::CaseStatement>(lit, block);
WrapInFunction(v, cse);

View File

@@ -77,7 +77,7 @@ TEST_P(InferTypeTest_FromConstructorExpression, All) {
WrapInFunction(Decl(a), Assign(a_ident, "a"));
ASSERT_TRUE(r()->Resolve()) << r()->error();
ASSERT_EQ(TypeOf(a_ident), ty.pointer(rhs_type, sc));
ASSERT_EQ(TypeOf(a_ident), ty.pointer(rhs_type->UnwrapAliasIfNeeded(), sc));
}
static constexpr Params from_constructor_expression_cases[] = {
@@ -173,7 +173,7 @@ TEST_P(InferTypeTest_FromCallExpression, All) {
WrapInFunction(Decl(a), Assign(a_ident, "a"));
ASSERT_TRUE(r()->Resolve()) << r()->error();
ASSERT_EQ(TypeOf(a_ident), ty.pointer(rhs_type, sc));
ASSERT_EQ(TypeOf(a_ident), ty.pointer(rhs_type->UnwrapAliasIfNeeded(), sc));
}
static constexpr Params from_call_expression_cases[] = {
Params{ty_bool_},

View File

@@ -1693,7 +1693,7 @@ TEST_F(ResolverValidationTest, Expr_Constructor_Vector_Alias_Argument_Error) {
EXPECT_FALSE(r()->Resolve());
EXPECT_EQ(r()->error(),
"12:34 error: type in vector constructor does not match vector "
"type: expected 'f32', found 'u32'");
"type: expected 'f32', found 'UnsignedInt'");
}
TEST_F(ResolverValidationTest, Expr_Constructor_Vector_Alias_Argument_Success) {
@@ -2040,7 +2040,7 @@ TEST_F(ResolverValidationTest, Expr_MatrixConstructor_ArgumentTypeAlias_Error) {
EXPECT_FALSE(r()->Resolve());
EXPECT_EQ(r()->error(),
"12:34 error: expected argument type 'vec2<f32>' in 'mat2x2<f32>' "
"constructor, found 'vec2<u32>'");
"constructor, found 'VectorUnsigned2'");
}
TEST_P(MatrixConstructorTest, Expr_Constructor_ArgumentTypeAlias_Success) {