Migrate to using semantic::Expression

Remove the mutable `result_type` from the ast::Expression.
Replace this with the use of semantic::Expression.

Bug: tint:390
Change-Id: I1f0eaf0dce8fde46fefe50bf2c5fe5b2e4d2d2df
Reviewed-on: https://dawn-review.googlesource.com/c/tint/+/39007
Commit-Queue: Ben Clayton <bclayton@google.com>
Reviewed-by: dan sinclair <dsinclair@chromium.org>
This commit is contained in:
Ben Clayton 2021-01-29 16:43:41 +00:00 committed by Commit Bot service account
parent 5c186625b6
commit 3335254c1c
20 changed files with 548 additions and 503 deletions

View File

@ -14,6 +14,9 @@
#include "src/ast/expression.h" #include "src/ast/expression.h"
#include "src/semantic/expression.h"
#include "src/semantic/info.h"
TINT_INSTANTIATE_CLASS_ID(tint::ast::Expression); TINT_INSTANTIATE_CLASS_ID(tint::ast::Expression);
namespace tint { namespace tint {
@ -25,9 +28,9 @@ Expression::Expression(Expression&&) = default;
Expression::~Expression() = default; Expression::~Expression() = default;
void Expression::set_result_type(type::Type* type) { std::string Expression::result_type_str(const semantic::Info& sem) const {
// The expression result should never be an alias or access-controlled type auto* sem_expr = sem.Get(this);
result_type_ = type->UnwrapIfNeeded(); return sem_expr ? sem_expr->Type()->type_name() : "not set";
} }
} // namespace ast } // namespace ast

View File

@ -30,18 +30,6 @@ class Expression : public Castable<Expression, Node> {
public: public:
~Expression() override; ~Expression() override;
/// Sets the resulting type of this expression
/// @param type the result type to set
void set_result_type(type::Type* type);
/// @returns the resulting type from this expression
type::Type* result_type() const { return result_type_; }
/// @returns a string representation of the result type or 'not set' if no
/// result type present
std::string result_type_str(const semantic::Info&) const {
return result_type_ ? result_type_->type_name() : "not set";
}
protected: protected:
/// Constructor /// Constructor
/// @param source the source of the expression /// @param source the source of the expression
@ -49,10 +37,13 @@ class Expression : public Castable<Expression, Node> {
/// Move constructor /// Move constructor
Expression(Expression&&); Expression(Expression&&);
/// @param sem the semantic info for the program
/// @returns a string representation of the result type or 'not set' if no
/// result type present
std::string result_type_str(const semantic::Info& sem) const;
private: private:
Expression(const Expression&) = delete; Expression(const Expression&) = delete;
type::Type* result_type_ = nullptr; // Semantic info
}; };
/// A list of expressions /// A list of expressions

View File

@ -21,6 +21,7 @@
#include "src/clone_context.h" #include "src/clone_context.h"
#include "src/demangler.h" #include "src/demangler.h"
#include "src/program_builder.h" #include "src/program_builder.h"
#include "src/semantic/expression.h"
#include "src/type_determiner.h" #include "src/type_determiner.h"
namespace tint { namespace tint {
@ -102,6 +103,11 @@ bool Program::IsValid() const {
return is_valid_; return is_valid_;
} }
type::Type* Program::TypeOf(ast::Expression* expr) const {
auto* sem = Sem().Get(expr);
return sem ? sem->Type() : nullptr;
}
std::string Program::to_str(bool demangle) const { std::string Program::to_str(bool demangle) const {
AssertNotMoved(); AssertNotMoved();
auto str = ast_->to_str(Sem()); auto str = ast_->to_str(Sem());

View File

@ -115,6 +115,12 @@ class Program {
/// information /// information
bool IsValid() const; bool IsValid() const;
/// Helper for returning the resolved semantic type of the expression `expr`.
/// @param expr the AST expression
/// @return the resolved semantic type for the expression, or nullptr if the
/// expression has no resolved type.
type::Type* TypeOf(ast::Expression* expr) const;
/// @param demangle whether to automatically demangle the symbols in the /// @param demangle whether to automatically demangle the symbols in the
/// returned string /// returned string
/// @returns a string describing this program. /// @returns a string describing this program.

View File

@ -20,6 +20,7 @@
#include "src/clone_context.h" #include "src/clone_context.h"
#include "src/demangler.h" #include "src/demangler.h"
#include "src/semantic/expression.h"
#include "src/type/struct_type.h" #include "src/type/struct_type.h"
namespace tint { namespace tint {
@ -82,6 +83,11 @@ void ProgramBuilder::AssertNotMoved() const {
assert(!moved_); assert(!moved_);
} }
type::Type* ProgramBuilder::TypeOf(ast::Expression* expr) const {
auto* sem = Sem().Get(expr);
return sem ? sem->Type() : nullptr;
}
ProgramBuilder::TypesBuilder::TypesBuilder(ProgramBuilder* pb) : builder(pb) {} ProgramBuilder::TypesBuilder::TypesBuilder(ProgramBuilder* pb) : builder(pb) {}
ast::Variable* ProgramBuilder::Var(const std::string& name, ast::Variable* ProgramBuilder::Var(const std::string& name,

View File

@ -937,6 +937,15 @@ class ProgramBuilder {
source_ = Source(loc); source_ = Source(loc);
} }
/// Helper for returning the resolved semantic type of the expression `expr`.
/// @note As the TypeDeterminator is run when the Program is built, this will
/// only be useful for the TypeDeterminer itself and tests that use their own
/// TypeDeterminer.
/// @param expr the AST expression
/// @return the resolved semantic type for the expression, or nullptr if the
/// expression has no resolved type.
type::Type* TypeOf(ast::Expression* expr) const;
/// The builder types /// The builder types
TypesBuilder ty; TypesBuilder ty;

View File

@ -44,6 +44,7 @@
#include "src/ast/variable_decl_statement.h" #include "src/ast/variable_decl_statement.h"
#include "src/clone_context.h" #include "src/clone_context.h"
#include "src/program_builder.h" #include "src/program_builder.h"
#include "src/semantic/expression.h"
#include "src/type/array_type.h" #include "src/type/array_type.h"
#include "src/type/matrix_type.h" #include "src/type/matrix_type.h"
#include "src/type/u32_type.h" #include "src/type/u32_type.h"
@ -70,7 +71,7 @@ ast::ArrayAccessorExpression* BoundArrayAccessors::Transform(
ast::ArrayAccessorExpression* expr, ast::ArrayAccessorExpression* expr,
CloneContext* ctx, CloneContext* ctx,
diag::List* diags) { diag::List* diags) {
auto* ret_type = expr->array()->result_type()->UnwrapAll(); auto* ret_type = ctx->src->Sem().Get(expr->array())->Type()->UnwrapAll();
if (!ret_type->Is<type::Array>() && !ret_type->Is<type::Matrix>() && if (!ret_type->Is<type::Array>() && !ret_type->Is<type::Matrix>() &&
!ret_type->Is<type::Vector>()) { !ret_type->Is<type::Vector>()) {
return nullptr; return nullptr;

View File

@ -43,6 +43,7 @@
#include "src/ast/unary_op_expression.h" #include "src/ast/unary_op_expression.h"
#include "src/ast/variable_decl_statement.h" #include "src/ast/variable_decl_statement.h"
#include "src/program_builder.h" #include "src/program_builder.h"
#include "src/semantic/expression.h"
#include "src/type/array_type.h" #include "src/type/array_type.h"
#include "src/type/bool_type.h" #include "src/type/bool_type.h"
#include "src/type/depth_texture_type.h" #include "src/type/depth_texture_type.h"
@ -308,6 +309,10 @@ bool TypeDeterminer::DetermineResultType(ast::Expression* expr) {
return true; return true;
} }
if (TypeOf(expr)) {
return true; // Already resolved
}
if (auto* a = expr->As<ast::ArrayAccessorExpression>()) { if (auto* a = expr->As<ast::ArrayAccessorExpression>()) {
return DetermineArrayAccessor(a); return DetermineArrayAccessor(a);
} }
@ -346,7 +351,7 @@ bool TypeDeterminer::DetermineArrayAccessor(
return false; return false;
} }
auto* res = expr->array()->result_type(); auto* res = TypeOf(expr->array());
auto* parent_type = res->UnwrapAll(); auto* parent_type = res->UnwrapAll();
type::Type* ret = nullptr; type::Type* ret = nullptr;
if (auto* arr = parent_type->As<type::Array>()) { if (auto* arr = parent_type->As<type::Array>()) {
@ -373,7 +378,7 @@ bool TypeDeterminer::DetermineArrayAccessor(
ret = builder_->create<type::Pointer>(ret, ast::StorageClass::kFunction); ret = builder_->create<type::Pointer>(ret, ast::StorageClass::kFunction);
} }
} }
expr->set_result_type(ret); SetType(expr, ret);
return true; return true;
} }
@ -382,7 +387,7 @@ bool TypeDeterminer::DetermineBitcast(ast::BitcastExpression* expr) {
if (!DetermineResultType(expr->expr())) { if (!DetermineResultType(expr->expr())) {
return false; return false;
} }
expr->set_result_type(expr->type()); SetType(expr, expr->type());
return true; return true;
} }
@ -420,12 +425,6 @@ bool TypeDeterminer::DetermineCall(ast::CallExpression* expr) {
set_referenced_from_function_if_needed(var, false); set_referenced_from_function_if_needed(var, false);
} }
} }
// An identifier with a single name is a function call, not an import
// lookup which we can handle with the regular identifier lookup.
if (!DetermineResultType(ident)) {
return false;
}
} }
} else { } else {
if (!DetermineResultType(expr->func())) { if (!DetermineResultType(expr->func())) {
@ -433,7 +432,9 @@ bool TypeDeterminer::DetermineCall(ast::CallExpression* expr) {
} }
} }
if (!expr->func()->result_type()) { if (auto* type = TypeOf(expr->func())) {
SetType(expr, type);
} else {
auto func_sym = expr->func()->As<ast::IdentifierExpression>()->symbol(); auto func_sym = expr->func()->As<ast::IdentifierExpression>()->symbol();
set_error(expr->source(), set_error(expr->source(),
"v-0005: function must be declared before use: '" + "v-0005: function must be declared before use: '" +
@ -441,7 +442,6 @@ bool TypeDeterminer::DetermineCall(ast::CallExpression* expr) {
return false; return false;
} }
expr->set_result_type(expr->func()->result_type());
return true; return true;
} }
@ -530,17 +530,17 @@ bool TypeDeterminer::DetermineIntrinsic(ast::IdentifierExpression* ident,
} }
// The result type must be the same as the type of the parameter. // The result type must be the same as the type of the parameter.
auto* param_type = expr->params()[0]->result_type()->UnwrapPtrIfNeeded(); auto* param_type = TypeOf(expr->params()[0])->UnwrapPtrIfNeeded();
expr->func()->set_result_type(param_type); SetType(expr->func(), param_type);
return true; return true;
} }
if (ident->intrinsic() == ast::Intrinsic::kAny || if (ident->intrinsic() == ast::Intrinsic::kAny ||
ident->intrinsic() == ast::Intrinsic::kAll) { ident->intrinsic() == ast::Intrinsic::kAll) {
expr->func()->set_result_type(builder_->create<type::Bool>()); SetType(expr->func(), builder_->create<type::Bool>());
return true; return true;
} }
if (ident->intrinsic() == ast::Intrinsic::kArrayLength) { if (ident->intrinsic() == ast::Intrinsic::kArrayLength) {
expr->func()->set_result_type(builder_->create<type::U32>()); SetType(expr->func(), builder_->create<type::U32>());
return true; return true;
} }
if (ast::intrinsic::IsFloatClassificationIntrinsic(ident->intrinsic())) { if (ast::intrinsic::IsFloatClassificationIntrinsic(ident->intrinsic())) {
@ -553,12 +553,12 @@ bool TypeDeterminer::DetermineIntrinsic(ast::IdentifierExpression* ident,
auto* bool_type = builder_->create<type::Bool>(); auto* bool_type = builder_->create<type::Bool>();
auto* param_type = expr->params()[0]->result_type()->UnwrapPtrIfNeeded(); auto* param_type = TypeOf(expr->params()[0])->UnwrapPtrIfNeeded();
if (auto* vec = param_type->As<type::Vector>()) { if (auto* vec = param_type->As<type::Vector>()) {
expr->func()->set_result_type( SetType(expr->func(),
builder_->create<type::Vector>(bool_type, vec->size())); builder_->create<type::Vector>(bool_type, vec->size()));
} else { } else {
expr->func()->set_result_type(bool_type); SetType(expr->func(), bool_type);
} }
return true; return true;
} }
@ -566,14 +566,14 @@ bool TypeDeterminer::DetermineIntrinsic(ast::IdentifierExpression* ident,
ast::intrinsic::TextureSignature::Parameters param; ast::intrinsic::TextureSignature::Parameters param;
auto* texture_param = expr->params()[0]; auto* texture_param = expr->params()[0];
if (!texture_param->result_type()->UnwrapAll()->Is<type::Texture>()) { if (!TypeOf(texture_param)->UnwrapAll()->Is<type::Texture>()) {
set_error(expr->source(), set_error(expr->source(),
"invalid first argument for " + "invalid first argument for " +
builder_->Symbols().NameFor(ident->symbol())); builder_->Symbols().NameFor(ident->symbol()));
return false; return false;
} }
type::Texture* texture = type::Texture* texture =
texture_param->result_type()->UnwrapAll()->As<type::Texture>(); TypeOf(texture_param)->UnwrapAll()->As<type::Texture>();
bool is_array = type::IsTextureArray(texture->dim()); bool is_array = type::IsTextureArray(texture->dim());
bool is_multisampled = texture->Is<type::MultisampledTexture>(); bool is_multisampled = texture->Is<type::MultisampledTexture>();
@ -744,12 +744,12 @@ bool TypeDeterminer::DetermineIntrinsic(ast::IdentifierExpression* ident,
} }
} }
} }
expr->func()->set_result_type(return_type); SetType(expr->func(), return_type);
return true; return true;
} }
if (ident->intrinsic() == ast::Intrinsic::kDot) { if (ident->intrinsic() == ast::Intrinsic::kDot) {
expr->func()->set_result_type(builder_->create<type::F32>()); SetType(expr->func(), builder_->create<type::F32>());
return true; return true;
} }
if (ident->intrinsic() == ast::Intrinsic::kSelect) { if (ident->intrinsic() == ast::Intrinsic::kSelect) {
@ -762,8 +762,8 @@ bool TypeDeterminer::DetermineIntrinsic(ast::IdentifierExpression* ident,
} }
// The result type must be the same as the type of the parameter. // The result type must be the same as the type of the parameter.
auto* param_type = expr->params()[0]->result_type()->UnwrapPtrIfNeeded(); auto* param_type = TypeOf(expr->params()[0])->UnwrapPtrIfNeeded();
expr->func()->set_result_type(param_type); SetType(expr->func(), param_type);
return true; return true;
} }
@ -791,8 +791,7 @@ bool TypeDeterminer::DetermineIntrinsic(ast::IdentifierExpression* ident,
std::vector<type::Type*> result_types; std::vector<type::Type*> result_types;
for (uint32_t i = 0; i < data->param_count; ++i) { for (uint32_t i = 0; i < data->param_count; ++i) {
result_types.push_back( result_types.push_back(TypeOf(expr->params()[i])->UnwrapPtrIfNeeded());
expr->params()[i]->result_type()->UnwrapPtrIfNeeded());
switch (data->data_type) { switch (data->data_type) {
case IntrinsicDataType::kFloatOrIntScalarOrVector: case IntrinsicDataType::kFloatOrIntScalarOrVector:
@ -869,18 +868,17 @@ bool TypeDeterminer::DetermineIntrinsic(ast::IdentifierExpression* ident,
// provided. // provided.
if (ident->intrinsic() == ast::Intrinsic::kLength || if (ident->intrinsic() == ast::Intrinsic::kLength ||
ident->intrinsic() == ast::Intrinsic::kDistance) { ident->intrinsic() == ast::Intrinsic::kDistance) {
expr->func()->set_result_type( SetType(expr->func(), result_types[0]->is_float_scalar()
result_types[0]->is_float_scalar()
? result_types[0] ? result_types[0]
: result_types[0]->As<type::Vector>()->type()); : result_types[0]->As<type::Vector>()->type());
return true; return true;
} }
// The determinant returns the component type of the columns // The determinant returns the component type of the columns
if (ident->intrinsic() == ast::Intrinsic::kDeterminant) { if (ident->intrinsic() == ast::Intrinsic::kDeterminant) {
expr->func()->set_result_type(result_types[0]->As<type::Matrix>()->type()); SetType(expr->func(), result_types[0]->As<type::Matrix>()->type());
return true; return true;
} }
expr->func()->set_result_type(result_types[0]); SetType(expr->func(), result_types[0]);
return true; return true;
} }
@ -891,9 +889,9 @@ bool TypeDeterminer::DetermineConstructor(ast::ConstructorExpression* expr) {
return false; return false;
} }
} }
expr->set_result_type(ty->type()); SetType(expr, ty->type());
} else { } else {
expr->set_result_type( SetType(expr,
expr->As<ast::ScalarConstructorExpression>()->literal()->type()); expr->As<ast::ScalarConstructorExpression>()->literal()->type());
} }
return true; return true;
@ -906,12 +904,12 @@ bool TypeDeterminer::DetermineIdentifier(ast::IdentifierExpression* expr) {
// A constant is the type, but a variable is always a pointer so synthesize // A constant is the type, but a variable is always a pointer so synthesize
// the pointer around the variable type. // the pointer around the variable type.
if (var->is_const()) { if (var->is_const()) {
expr->set_result_type(var->type()); SetType(expr, var->type());
} else if (var->type()->Is<type::Pointer>()) { } else if (var->type()->Is<type::Pointer>()) {
expr->set_result_type(var->type()); SetType(expr, var->type());
} else { } else {
expr->set_result_type( SetType(expr, builder_->create<type::Pointer>(var->type(),
builder_->create<type::Pointer>(var->type(), var->storage_class())); var->storage_class()));
} }
set_referenced_from_function_if_needed(var, true); set_referenced_from_function_if_needed(var, true);
@ -920,7 +918,7 @@ bool TypeDeterminer::DetermineIdentifier(ast::IdentifierExpression* expr) {
auto iter = symbol_to_function_.find(symbol); auto iter = symbol_to_function_.find(symbol);
if (iter != symbol_to_function_.end()) { if (iter != symbol_to_function_.end()) {
expr->set_result_type(iter->second->return_type()); SetType(expr, iter->second->return_type());
return true; return true;
} }
@ -1091,7 +1089,7 @@ bool TypeDeterminer::DetermineMemberAccessor(
return false; return false;
} }
auto* res = expr->structure()->result_type(); auto* res = TypeOf(expr->structure());
auto* data_type = res->UnwrapPtrIfNeeded()->UnwrapIfNeeded(); auto* data_type = res->UnwrapPtrIfNeeded()->UnwrapIfNeeded();
type::Type* ret = nullptr; type::Type* ret = nullptr;
@ -1143,7 +1141,7 @@ bool TypeDeterminer::DetermineMemberAccessor(
return false; return false;
} }
expr->set_result_type(ret); SetType(expr, ret);
return true; return true;
} }
@ -1157,7 +1155,7 @@ bool TypeDeterminer::DetermineBinary(ast::BinaryExpression* expr) {
if (expr->IsAnd() || expr->IsOr() || expr->IsXor() || expr->IsShiftLeft() || if (expr->IsAnd() || expr->IsOr() || expr->IsXor() || expr->IsShiftLeft() ||
expr->IsShiftRight() || expr->IsAdd() || expr->IsSubtract() || expr->IsShiftRight() || expr->IsAdd() || expr->IsSubtract() ||
expr->IsDivide() || expr->IsModulo()) { expr->IsDivide() || expr->IsModulo()) {
expr->set_result_type(expr->lhs()->result_type()->UnwrapPtrIfNeeded()); SetType(expr, TypeOf(expr->lhs())->UnwrapPtrIfNeeded());
return true; return true;
} }
// Result type is a scalar or vector of boolean type // Result type is a scalar or vector of boolean type
@ -1165,18 +1163,17 @@ bool TypeDeterminer::DetermineBinary(ast::BinaryExpression* expr) {
expr->IsNotEqual() || expr->IsLessThan() || expr->IsGreaterThan() || expr->IsNotEqual() || expr->IsLessThan() || expr->IsGreaterThan() ||
expr->IsLessThanEqual() || expr->IsGreaterThanEqual()) { expr->IsLessThanEqual() || expr->IsGreaterThanEqual()) {
auto* bool_type = builder_->create<type::Bool>(); auto* bool_type = builder_->create<type::Bool>();
auto* param_type = expr->lhs()->result_type()->UnwrapPtrIfNeeded(); auto* param_type = TypeOf(expr->lhs())->UnwrapPtrIfNeeded();
type::Type* result_type = bool_type;
if (auto* vec = param_type->As<type::Vector>()) { if (auto* vec = param_type->As<type::Vector>()) {
expr->set_result_type( result_type = builder_->create<type::Vector>(bool_type, vec->size());
builder_->create<type::Vector>(bool_type, vec->size()));
} else {
expr->set_result_type(bool_type);
} }
SetType(expr, result_type);
return true; return true;
} }
if (expr->IsMultiply()) { if (expr->IsMultiply()) {
auto* lhs_type = expr->lhs()->result_type()->UnwrapPtrIfNeeded(); auto* lhs_type = TypeOf(expr->lhs())->UnwrapPtrIfNeeded();
auto* rhs_type = expr->rhs()->result_type()->UnwrapPtrIfNeeded(); auto* rhs_type = TypeOf(expr->rhs())->UnwrapPtrIfNeeded();
// Note, the ordering here matters. The later checks depend on the prior // Note, the ordering here matters. The later checks depend on the prior
// checks having been done. // checks having been done.
@ -1184,34 +1181,36 @@ bool TypeDeterminer::DetermineBinary(ast::BinaryExpression* expr) {
auto* rhs_mat = rhs_type->As<type::Matrix>(); auto* rhs_mat = rhs_type->As<type::Matrix>();
auto* lhs_vec = lhs_type->As<type::Vector>(); auto* lhs_vec = lhs_type->As<type::Vector>();
auto* rhs_vec = rhs_type->As<type::Vector>(); auto* rhs_vec = rhs_type->As<type::Vector>();
type::Type* result_type;
if (lhs_mat && rhs_mat) { if (lhs_mat && rhs_mat) {
expr->set_result_type(builder_->create<type::Matrix>( result_type = builder_->create<type::Matrix>(
lhs_mat->type(), lhs_mat->rows(), rhs_mat->columns())); lhs_mat->type(), lhs_mat->rows(), rhs_mat->columns());
} else if (lhs_mat && rhs_vec) { } else if (lhs_mat && rhs_vec) {
expr->set_result_type( result_type =
builder_->create<type::Vector>(lhs_mat->type(), lhs_mat->rows())); builder_->create<type::Vector>(lhs_mat->type(), lhs_mat->rows());
} else if (lhs_vec && rhs_mat) { } else if (lhs_vec && rhs_mat) {
expr->set_result_type( result_type =
builder_->create<type::Vector>(rhs_mat->type(), rhs_mat->columns())); builder_->create<type::Vector>(rhs_mat->type(), rhs_mat->columns());
} else if (lhs_mat) { } else if (lhs_mat) {
// matrix * scalar // matrix * scalar
expr->set_result_type(lhs_type); result_type = lhs_type;
} else if (rhs_mat) { } else if (rhs_mat) {
// scalar * matrix // scalar * matrix
expr->set_result_type(rhs_type); result_type = rhs_type;
} else if (lhs_vec && rhs_vec) { } else if (lhs_vec && rhs_vec) {
expr->set_result_type(lhs_type); result_type = lhs_type;
} else if (lhs_vec) { } else if (lhs_vec) {
// Vector * scalar // Vector * scalar
expr->set_result_type(lhs_type); result_type = lhs_type;
} else if (rhs_vec) { } else if (rhs_vec) {
// Scalar * vector // Scalar * vector
expr->set_result_type(rhs_type); result_type = rhs_type;
} else { } else {
// Scalar * Scalar // Scalar * Scalar
expr->set_result_type(lhs_type); result_type = lhs_type;
} }
SetType(expr, result_type);
return true; return true;
} }
@ -1224,7 +1223,9 @@ bool TypeDeterminer::DetermineUnaryOp(ast::UnaryOpExpression* expr) {
if (!DetermineResultType(expr->expr())) { if (!DetermineResultType(expr->expr())) {
return false; return false;
} }
expr->set_result_type(expr->expr()->result_type()->UnwrapPtrIfNeeded());
auto* result_type = TypeOf(expr->expr())->UnwrapPtrIfNeeded();
SetType(expr, result_type);
return true; return true;
} }
@ -1288,4 +1289,9 @@ bool TypeDeterminer::DetermineStorageTextureSubtype(type::StorageTexture* tex) {
return false; return false;
} }
void TypeDeterminer::SetType(ast::Expression* expr, type::Type* type) const {
return builder_->Sem().Add(expr,
builder_->create<semantic::Expression>(type));
}
} // namespace tint } // namespace tint

View File

@ -21,6 +21,7 @@
#include "src/ast/module.h" #include "src/ast/module.h"
#include "src/diagnostic/diagnostic.h" #include "src/diagnostic/diagnostic.h"
#include "src/program_builder.h"
#include "src/scope_stack.h" #include "src/scope_stack.h"
#include "src/type/storage_texture_type.h" #include "src/type/storage_texture_type.h"
@ -137,6 +138,18 @@ class TypeDeterminer {
bool DetermineMemberAccessor(ast::MemberAccessorExpression* expr); bool DetermineMemberAccessor(ast::MemberAccessorExpression* expr);
bool DetermineUnaryOp(ast::UnaryOpExpression* expr); bool DetermineUnaryOp(ast::UnaryOpExpression* expr);
/// @returns the resolved type of the ast::Expression `expr`
/// @param expr the expression
type::Type* TypeOf(ast::Expression* expr) const {
return builder_->TypeOf(expr);
}
/// Creates a semantic::Expression node with the resolved type `type`, and
/// assigns this semantic node to the expression `expr`.
/// @param expr the expression
/// @param type the resolved type
void SetType(ast::Expression* expr, type::Type* type) const;
ProgramBuilder* builder_; ProgramBuilder* builder_;
std::string error_; std::string error_;
ScopeStack<ast::Variable*> variable_stack_; ScopeStack<ast::Variable*> variable_stack_;

File diff suppressed because it is too large Load Diff

View File

@ -30,6 +30,7 @@
#include "src/ast/switch_statement.h" #include "src/ast/switch_statement.h"
#include "src/ast/uint_literal.h" #include "src/ast/uint_literal.h"
#include "src/ast/variable_decl_statement.h" #include "src/ast/variable_decl_statement.h"
#include "src/semantic/expression.h"
#include "src/type/alias_type.h" #include "src/type/alias_type.h"
#include "src/type/array_type.h" #include "src/type/array_type.h"
#include "src/type/i32_type.h" #include "src/type/i32_type.h"
@ -236,8 +237,9 @@ bool ValidatorImpl::ValidateReturnStatement(const ast::ReturnStatement* ret) {
type::Type* func_type = current_function_->return_type(); type::Type* func_type = current_function_->return_type();
type::Void void_type; type::Void void_type;
auto* ret_type = auto* ret_type = ret->has_value()
ret->has_value() ? ret->value()->result_type()->UnwrapAll() : &void_type; ? program_->Sem().Get(ret->value())->Type()->UnwrapAll()
: &void_type;
if (func_type->type_name() != ret_type->type_name()) { if (func_type->type_name() != ret_type->type_name()) {
add_error(ret->source(), "v-000y", add_error(ret->source(), "v-000y",
@ -328,7 +330,7 @@ bool ValidatorImpl::ValidateSwitch(const ast::SwitchStatement* s) {
return false; return false;
} }
auto* cond_type = s->condition()->result_type()->UnwrapAll(); auto* cond_type = program_->Sem().Get(s->condition())->Type()->UnwrapAll();
if (!cond_type->is_integer_scalar()) { if (!cond_type->is_integer_scalar()) {
add_error(s->condition()->source(), "v-0025", add_error(s->condition()->source(), "v-0025",
"switch statement selector expression must be of a " "switch statement selector expression must be of a "
@ -472,14 +474,14 @@ bool ValidatorImpl::ValidateAssign(const ast::AssignmentStatement* assign) {
// Pointers are not storable in WGSL, but the right-hand side must be // Pointers are not storable in WGSL, but the right-hand side must be
// storable. The raw right-hand side might be a pointer value which must be // storable. The raw right-hand side might be a pointer value which must be
// loaded (dereferenced) to provide the value to be stored. // loaded (dereferenced) to provide the value to be stored.
auto* rhs_result_type = rhs->result_type()->UnwrapAll(); auto* rhs_result_type = program_->Sem().Get(rhs)->Type()->UnwrapAll();
if (!IsStorable(rhs_result_type)) { if (!IsStorable(rhs_result_type)) {
add_error(assign->source(), "v-000x", add_error(assign->source(), "v-000x",
"invalid assignment: right-hand-side is not storable: " + "invalid assignment: right-hand-side is not storable: " +
rhs->result_type()->type_name()); program_->Sem().Get(rhs)->Type()->type_name());
return false; return false;
} }
auto* lhs_result_type = lhs->result_type()->UnwrapIfNeeded(); auto* lhs_result_type = program_->Sem().Get(lhs)->Type()->UnwrapIfNeeded();
if (auto* lhs_reference_type = As<type::Pointer>(lhs_result_type)) { if (auto* lhs_reference_type = As<type::Pointer>(lhs_result_type)) {
auto* lhs_store_type = lhs_reference_type->type()->UnwrapIfNeeded(); auto* lhs_store_type = lhs_reference_type->type()->UnwrapIfNeeded();
if (lhs_store_type != rhs_result_type) { if (lhs_store_type != rhs_result_type) {
@ -497,7 +499,7 @@ bool ValidatorImpl::ValidateAssign(const ast::AssignmentStatement* assign) {
add_error( add_error(
assign->source(), "v-000x", assign->source(), "v-000x",
"invalid assignment: left-hand-side does not reference storage: " + "invalid assignment: left-hand-side does not reference storage: " +
lhs->result_type()->type_name()); program_->Sem().Get(lhs)->Type()->type_name());
return false; return false;
} }

View File

@ -130,8 +130,8 @@ TEST_F(ValidatorTest, AssignCompatibleTypes_Pass) {
Source{Source::Location{12, 34}}, lhs, rhs); Source{Source::Location{12, 34}}, lhs, rhs);
RegisterVariable(var); RegisterVariable(var);
EXPECT_TRUE(td()->DetermineResultType(assign)) << td()->error(); EXPECT_TRUE(td()->DetermineResultType(assign)) << td()->error();
ASSERT_NE(lhs->result_type(), nullptr); ASSERT_NE(TypeOf(lhs), nullptr);
ASSERT_NE(rhs->result_type(), nullptr); ASSERT_NE(TypeOf(rhs), nullptr);
ValidatorImpl& v = Build(); ValidatorImpl& v = Build();
@ -153,8 +153,8 @@ TEST_F(ValidatorTest, AssignCompatibleTypesThroughAlias_Pass) {
Source{Source::Location{12, 34}}, lhs, rhs); Source{Source::Location{12, 34}}, lhs, rhs);
RegisterVariable(var); RegisterVariable(var);
EXPECT_TRUE(td()->DetermineResultType(assign)) << td()->error(); EXPECT_TRUE(td()->DetermineResultType(assign)) << td()->error();
ASSERT_NE(lhs->result_type(), nullptr); ASSERT_NE(TypeOf(lhs), nullptr);
ASSERT_NE(rhs->result_type(), nullptr); ASSERT_NE(TypeOf(rhs), nullptr);
ValidatorImpl& v = Build(); ValidatorImpl& v = Build();
@ -178,8 +178,8 @@ TEST_F(ValidatorTest, AssignCompatibleTypesInferRHSLoad_Pass) {
RegisterVariable(var_a); RegisterVariable(var_a);
RegisterVariable(var_b); RegisterVariable(var_b);
EXPECT_TRUE(td()->DetermineResultType(assign)) << td()->error(); EXPECT_TRUE(td()->DetermineResultType(assign)) << td()->error();
ASSERT_NE(lhs->result_type(), nullptr); ASSERT_NE(TypeOf(lhs), nullptr);
ASSERT_NE(rhs->result_type(), nullptr); ASSERT_NE(TypeOf(rhs), nullptr);
ValidatorImpl& v = Build(); ValidatorImpl& v = Build();
@ -203,8 +203,8 @@ TEST_F(ValidatorTest, AssignThroughPointer_Pass) {
RegisterVariable(var_a); RegisterVariable(var_a);
RegisterVariable(var_b); RegisterVariable(var_b);
EXPECT_TRUE(td()->DetermineResultType(assign)) << td()->error(); EXPECT_TRUE(td()->DetermineResultType(assign)) << td()->error();
ASSERT_NE(lhs->result_type(), nullptr); ASSERT_NE(TypeOf(lhs), nullptr);
ASSERT_NE(rhs->result_type(), nullptr); ASSERT_NE(TypeOf(rhs), nullptr);
ValidatorImpl& v = Build(); ValidatorImpl& v = Build();
@ -227,8 +227,8 @@ TEST_F(ValidatorTest, AssignIncompatibleTypes_Fail) {
Source{Source::Location{12, 34}}, lhs, rhs); Source{Source::Location{12, 34}}, lhs, rhs);
RegisterVariable(var); RegisterVariable(var);
EXPECT_TRUE(td()->DetermineResultType(assign)) << td()->error(); EXPECT_TRUE(td()->DetermineResultType(assign)) << td()->error();
ASSERT_NE(lhs->result_type(), nullptr); ASSERT_NE(TypeOf(lhs), nullptr);
ASSERT_NE(rhs->result_type(), nullptr); ASSERT_NE(TypeOf(rhs), nullptr);
ValidatorImpl& v = Build(); ValidatorImpl& v = Build();
@ -257,8 +257,8 @@ TEST_F(ValidatorTest, AssignThroughPointerWrongeStoreType_Fail) {
RegisterVariable(var_a); RegisterVariable(var_a);
RegisterVariable(var_b); RegisterVariable(var_b);
EXPECT_TRUE(td()->DetermineResultType(assign)) << td()->error(); EXPECT_TRUE(td()->DetermineResultType(assign)) << td()->error();
ASSERT_NE(lhs->result_type(), nullptr); ASSERT_NE(TypeOf(lhs), nullptr);
ASSERT_NE(rhs->result_type(), nullptr); ASSERT_NE(TypeOf(rhs), nullptr);
ValidatorImpl& v = Build(); ValidatorImpl& v = Build();
@ -286,8 +286,8 @@ TEST_F(ValidatorTest, AssignCompatibleTypesInBlockStatement_Pass) {
}); });
EXPECT_TRUE(td()->DetermineStatements(body)) << td()->error(); EXPECT_TRUE(td()->DetermineStatements(body)) << td()->error();
ASSERT_NE(lhs->result_type(), nullptr); ASSERT_NE(TypeOf(lhs), nullptr);
ASSERT_NE(rhs->result_type(), nullptr); ASSERT_NE(TypeOf(rhs), nullptr);
ValidatorImpl& v = Build(); ValidatorImpl& v = Build();
@ -313,8 +313,8 @@ TEST_F(ValidatorTest, AssignIncompatibleTypesInBlockStatement_Fail) {
}); });
EXPECT_TRUE(td()->DetermineStatements(block)) << td()->error(); EXPECT_TRUE(td()->DetermineStatements(block)) << td()->error();
ASSERT_NE(lhs->result_type(), nullptr); ASSERT_NE(TypeOf(lhs), nullptr);
ASSERT_NE(rhs->result_type(), nullptr); ASSERT_NE(TypeOf(rhs), nullptr);
ValidatorImpl& v = Build(); ValidatorImpl& v = Build();
@ -461,8 +461,8 @@ TEST_F(ValidatorTest, UsingUndefinedVariableInnerScope_Fail) {
}); });
EXPECT_TRUE(td()->DetermineStatements(outer_body)) << td()->error(); EXPECT_TRUE(td()->DetermineStatements(outer_body)) << td()->error();
ASSERT_NE(lhs->result_type(), nullptr); ASSERT_NE(TypeOf(lhs), nullptr);
ASSERT_NE(rhs->result_type(), nullptr); ASSERT_NE(TypeOf(rhs), nullptr);
ValidatorImpl& v = Build(); ValidatorImpl& v = Build();
@ -494,8 +494,8 @@ TEST_F(ValidatorTest, UsingUndefinedVariableOuterScope_Pass) {
}); });
EXPECT_TRUE(td()->DetermineStatements(outer_body)) << td()->error(); EXPECT_TRUE(td()->DetermineStatements(outer_body)) << td()->error();
ASSERT_NE(lhs->result_type(), nullptr); ASSERT_NE(TypeOf(lhs), nullptr);
ASSERT_NE(rhs->result_type(), nullptr); ASSERT_NE(TypeOf(rhs), nullptr);
ValidatorImpl& v = Build(); ValidatorImpl& v = Build();
@ -559,8 +559,8 @@ TEST_F(ValidatorTest, AssignToConstant_Fail) {
}); });
EXPECT_TRUE(td()->DetermineStatements(body)) << td()->error(); EXPECT_TRUE(td()->DetermineStatements(body)) << td()->error();
ASSERT_NE(lhs->result_type(), nullptr); ASSERT_NE(TypeOf(lhs), nullptr);
ASSERT_NE(rhs->result_type(), nullptr); ASSERT_NE(TypeOf(rhs), nullptr);
ValidatorImpl& v = Build(); ValidatorImpl& v = Build();
@ -592,7 +592,6 @@ TEST_F(ValidatorTest, GlobalVariableFunctionVariableNotUnique_Fail) {
AST().Functions().Add(func); AST().Functions().Add(func);
EXPECT_TRUE(td()->Determine()) << td()->error(); EXPECT_TRUE(td()->Determine()) << td()->error();
EXPECT_TRUE(td()->DetermineFunction(func)) << td()->error();
ValidatorImpl& v = Build(); ValidatorImpl& v = Build();
@ -622,7 +621,6 @@ TEST_F(ValidatorTest, RedeclaredIndentifier_Fail) {
AST().Functions().Add(func); AST().Functions().Add(func);
EXPECT_TRUE(td()->Determine()) << td()->error(); EXPECT_TRUE(td()->Determine()) << td()->error();
EXPECT_TRUE(td()->DetermineFunction(func)) << td()->error();
ValidatorImpl& v = Build(); ValidatorImpl& v = Build();
@ -747,8 +745,8 @@ TEST_F(ValidatorTest, VariableDeclNoConstructor_Pass) {
}); });
EXPECT_TRUE(td()->DetermineStatements(body)) << td()->error(); EXPECT_TRUE(td()->DetermineStatements(body)) << td()->error();
ASSERT_NE(lhs->result_type(), nullptr); ASSERT_NE(TypeOf(lhs), nullptr);
ASSERT_NE(rhs->result_type(), nullptr); ASSERT_NE(TypeOf(rhs), nullptr);
ValidatorImpl& v = Build(); ValidatorImpl& v = Build();

View File

@ -21,6 +21,7 @@
#include <vector> #include <vector>
#include "src/program_builder.h" #include "src/program_builder.h"
#include "src/semantic/expression.h"
#include "src/type/void_type.h" #include "src/type/void_type.h"
#include "src/type_determiner.h" #include "src/type_determiner.h"
#include "src/validator/validator_impl.h" #include "src/validator/validator_impl.h"

View File

@ -18,6 +18,7 @@
#include "src/ast/expression.h" #include "src/ast/expression.h"
#include "src/ast/type_constructor_expression.h" #include "src/ast/type_constructor_expression.h"
#include "src/semantic/expression.h"
#include "src/semantic/info.h" #include "src/semantic/info.h"
#include "src/type/vector_type.h" #include "src/type/vector_type.h"
@ -42,21 +43,18 @@ ast::TypeConstructorExpression* AppendVector(ProgramBuilder* b,
ast::Expression* scalar) { ast::Expression* scalar) {
uint32_t packed_size; uint32_t packed_size;
type::Type* packed_el_ty; // Currently must be f32. type::Type* packed_el_ty; // Currently must be f32.
if (auto* vec = vector->result_type()->As<type::Vector>()) { auto* vector_sem = b->Sem().Get(vector);
if (auto* vec = vector_sem->Type()->As<type::Vector>()) {
packed_size = vec->size() + 1; packed_size = vec->size() + 1;
packed_el_ty = vec->type(); packed_el_ty = vec->type();
} else { } else {
packed_size = 2; packed_size = 2;
packed_el_ty = vector->result_type(); packed_el_ty = vector_sem->Type();
}
if (!packed_el_ty) {
return nullptr; // missing type info
} }
// Cast scalar to the vector element type // Cast scalar to the vector element type
auto* scalar_cast = b->Construct(packed_el_ty, scalar); auto* scalar_cast = b->Construct(packed_el_ty, scalar);
scalar_cast->set_result_type(packed_el_ty); b->Sem().Add(scalar_cast, b->create<semantic::Expression>(packed_el_ty));
auto* packed_ty = b->create<type::Vector>(packed_el_ty, packed_size); auto* packed_ty = b->create<type::Vector>(packed_el_ty, packed_size);
@ -68,14 +66,14 @@ ast::TypeConstructorExpression* AppendVector(ProgramBuilder* b,
} else { } else {
packed.emplace_back(vector); packed.emplace_back(vector);
} }
if (packed_el_ty != scalar->result_type()) { if (packed_el_ty != b->Sem().Get(scalar)->Type()) {
packed.emplace_back(scalar_cast); packed.emplace_back(scalar_cast);
} else { } else {
packed.emplace_back(scalar); packed.emplace_back(scalar);
} }
auto* constructor = b->Construct(packed_ty, std::move(packed)); auto* constructor = b->Construct(packed_ty, std::move(packed));
constructor->set_result_type(packed_ty); b->Sem().Add(constructor, b->create<semantic::Expression>(packed_ty));
return constructor; return constructor;
} }

View File

@ -45,6 +45,7 @@
#include "src/ast/variable.h" #include "src/ast/variable.h"
#include "src/ast/variable_decl_statement.h" #include "src/ast/variable_decl_statement.h"
#include "src/program_builder.h" #include "src/program_builder.h"
#include "src/semantic/expression.h"
#include "src/type/access_control_type.h" #include "src/type/access_control_type.h"
#include "src/type/alias_type.h" #include "src/type/alias_type.h"
#include "src/type/array_type.h" #include "src/type/array_type.h"
@ -383,8 +384,8 @@ bool GeneratorImpl::EmitBinary(std::ostream& pre,
return true; return true;
} }
auto* lhs_type = expr->lhs()->result_type()->UnwrapAll(); auto* lhs_type = TypeOf(expr->lhs())->UnwrapAll();
auto* rhs_type = expr->rhs()->result_type()->UnwrapAll(); auto* rhs_type = TypeOf(expr->rhs())->UnwrapAll();
// Multiplying by a matrix requires the use of `mul` in order to get the // Multiplying by a matrix requires the use of `mul` in order to get the
// type of multiply we desire. // type of multiply we desire.
if (expr->op() == ast::BinaryOp::kMultiply && if (expr->op() == ast::BinaryOp::kMultiply &&
@ -692,7 +693,7 @@ bool GeneratorImpl::EmitTextureCall(std::ostream& pre,
auto const kNotUsed = ast::intrinsic::TextureSignature::Parameters::kNotUsed; auto const kNotUsed = ast::intrinsic::TextureSignature::Parameters::kNotUsed;
auto* texture = params[pidx.texture]; auto* texture = params[pidx.texture];
auto* texture_type = texture->result_type()->UnwrapAll()->As<type::Texture>(); auto* texture_type = TypeOf(texture)->UnwrapAll()->As<type::Texture>();
switch (ident->intrinsic()) { switch (ident->intrinsic()) {
case ast::Intrinsic::kTextureDimensions: case ast::Intrinsic::kTextureDimensions:
@ -887,7 +888,7 @@ bool GeneratorImpl::EmitTextureCall(std::ostream& pre,
auto emit_vector_appended_with_i32_zero = [&](tint::ast::Expression* vector) { auto emit_vector_appended_with_i32_zero = [&](tint::ast::Expression* vector) {
auto* i32 = builder_.create<type::I32>(); auto* i32 = builder_.create<type::I32>();
auto* zero = builder_.Expr(0); auto* zero = builder_.Expr(0);
zero->set_result_type(i32); builder_.Sem().Add(zero, builder_.create<semantic::Expression>(i32));
auto* packed = AppendVector(&builder_, vector, zero); auto* packed = AppendVector(&builder_, vector, zero);
return EmitExpression(pre, out, packed); return EmitExpression(pre, out, packed);
}; };
@ -1857,7 +1858,7 @@ std::string GeneratorImpl::generate_storage_buffer_index_expression(
} }
first = false; first = false;
if (auto* mem = expr->As<ast::MemberAccessorExpression>()) { if (auto* mem = expr->As<ast::MemberAccessorExpression>()) {
auto* res_type = mem->structure()->result_type()->UnwrapAll(); auto* res_type = TypeOf(mem->structure())->UnwrapAll();
if (auto* str = res_type->As<type::Struct>()) { if (auto* str = res_type->As<type::Struct>()) {
auto* str_type = str->impl(); auto* str_type = str->impl();
auto* str_member = str_type->get_member(mem->member()->symbol()); auto* str_member = str_type->get_member(mem->member()->symbol());
@ -1895,7 +1896,7 @@ std::string GeneratorImpl::generate_storage_buffer_index_expression(
expr = mem->structure(); expr = mem->structure();
} else if (auto* ary = expr->As<ast::ArrayAccessorExpression>()) { } else if (auto* ary = expr->As<ast::ArrayAccessorExpression>()) {
auto* ary_type = ary->array()->result_type()->UnwrapAll(); auto* ary_type = TypeOf(ary->array())->UnwrapAll();
out << "("; out << "(";
if (auto* arr = ary_type->As<type::Array>()) { if (auto* arr = ary_type->As<type::Array>()) {
@ -1942,7 +1943,7 @@ bool GeneratorImpl::EmitStorageBufferAccessor(std::ostream& pre,
std::ostream& out, std::ostream& out,
ast::Expression* expr, ast::Expression* expr,
ast::Expression* rhs) { ast::Expression* rhs) {
auto* result_type = expr->result_type()->UnwrapAll(); auto* result_type = TypeOf(expr)->UnwrapAll();
bool is_store = rhs != nullptr; bool is_store = rhs != nullptr;
std::string access_method = is_store ? "Store" : "Load"; std::string access_method = is_store ? "Store" : "Load";
@ -2058,7 +2059,7 @@ bool GeneratorImpl::is_storage_buffer_access(
bool GeneratorImpl::is_storage_buffer_access( bool GeneratorImpl::is_storage_buffer_access(
ast::MemberAccessorExpression* expr) { ast::MemberAccessorExpression* expr) {
auto* structure = expr->structure(); auto* structure = expr->structure();
auto* data_type = structure->result_type()->UnwrapAll(); auto* data_type = TypeOf(structure)->UnwrapAll();
// TODO(dsinclair): Swizzle // TODO(dsinclair): Swizzle
// //
// If the data is a multi-element swizzle then we will not load the swizzle // If the data is a multi-element swizzle then we will not load the swizzle

View File

@ -390,6 +390,12 @@ class GeneratorImpl {
std::string current_ep_var_name(VarType type); std::string current_ep_var_name(VarType type);
std::string get_buffer_name(ast::Expression* expr); std::string get_buffer_name(ast::Expression* expr);
/// @returns the resolved type of the ast::Expression `expr`
/// @param expr the expression
type::Type* TypeOf(ast::Expression* expr) const {
return builder_.TypeOf(expr);
}
std::string error_; std::string error_;
size_t indent_ = 0; size_t indent_ = 0;

View File

@ -50,6 +50,7 @@
#include "src/ast/variable.h" #include "src/ast/variable.h"
#include "src/ast/variable_decl_statement.h" #include "src/ast/variable_decl_statement.h"
#include "src/program.h" #include "src/program.h"
#include "src/semantic/expression.h"
#include "src/type/access_control_type.h" #include "src/type/access_control_type.h"
#include "src/type/alias_type.h" #include "src/type/alias_type.h"
#include "src/type/array_type.h" #include "src/type/array_type.h"
@ -613,7 +614,7 @@ bool GeneratorImpl::EmitTextureCall(ast::CallExpression* expr) {
assert(pidx.texture != kNotUsed); assert(pidx.texture != kNotUsed);
auto* texture_type = auto* texture_type =
params[pidx.texture]->result_type()->UnwrapAll()->As<type::Texture>(); TypeOf(params[pidx.texture])->UnwrapAll()->As<type::Texture>();
switch (ident->intrinsic()) { switch (ident->intrinsic()) {
case ast::Intrinsic::kTextureDimensions: { case ast::Intrinsic::kTextureDimensions: {
@ -658,7 +659,7 @@ bool GeneratorImpl::EmitTextureCall(ast::CallExpression* expr) {
get_dim(dims[0]); get_dim(dims[0]);
out_ << ")"; out_ << ")";
} else { } else {
EmitType(expr->result_type(), ""); EmitType(TypeOf(expr), "");
out_ << "("; out_ << "(";
for (size_t i = 0; i < dims.size(); i++) { for (size_t i = 0; i < dims.size(); i++) {
if (i > 0) { if (i > 0) {
@ -764,8 +765,7 @@ bool GeneratorImpl::EmitTextureCall(ast::CallExpression* expr) {
} }
} }
if (pidx.ddx != kNotUsed) { if (pidx.ddx != kNotUsed) {
auto dim = params[pidx.texture] auto dim = TypeOf(params[pidx.texture])
->result_type()
->UnwrapPtrIfNeeded() ->UnwrapPtrIfNeeded()
->As<type::Texture>() ->As<type::Texture>()
->dim(); ->dim();
@ -815,6 +815,7 @@ bool GeneratorImpl::EmitTextureCall(ast::CallExpression* expr) {
std::string GeneratorImpl::generate_builtin_name( std::string GeneratorImpl::generate_builtin_name(
ast::IdentifierExpression* ident) { ast::IdentifierExpression* ident) {
auto* type = TypeOf(ident);
std::string out = "metal::"; std::string out = "metal::";
switch (ident->intrinsic()) { switch (ident->intrinsic()) {
case ast::Intrinsic::kAcos: case ast::Intrinsic::kAcos:
@ -852,26 +853,23 @@ std::string GeneratorImpl::generate_builtin_name(
out += program_->Symbols().NameFor(ident->symbol()); out += program_->Symbols().NameFor(ident->symbol());
break; break;
case ast::Intrinsic::kAbs: case ast::Intrinsic::kAbs:
if (ident->result_type()->Is<type::F32>()) { if (type->Is<type::F32>()) {
out += "fabs"; out += "fabs";
} else if (ident->result_type()->Is<type::U32>() || } else if (type->Is<type::U32>() || type->Is<type::I32>()) {
ident->result_type()->Is<type::I32>()) {
out += "abs"; out += "abs";
} }
break; break;
case ast::Intrinsic::kMax: case ast::Intrinsic::kMax:
if (ident->result_type()->Is<type::F32>()) { if (type->Is<type::F32>()) {
out += "fmax"; out += "fmax";
} else if (ident->result_type()->Is<type::U32>() || } else if (type->Is<type::U32>() || type->Is<type::I32>()) {
ident->result_type()->Is<type::I32>()) {
out += "max"; out += "max";
} }
break; break;
case ast::Intrinsic::kMin: case ast::Intrinsic::kMin:
if (ident->result_type()->Is<type::F32>()) { if (type->Is<type::F32>()) {
out += "fmin"; out += "fmin";
} else if (ident->result_type()->Is<type::U32>() || } else if (type->Is<type::U32>() || type->Is<type::I32>()) {
ident->result_type()->Is<type::I32>()) {
out += "min"; out += "min";
} }
break; break;

View File

@ -280,6 +280,12 @@ class GeneratorImpl : public TextGenerator {
std::string current_ep_var_name(VarType type); std::string current_ep_var_name(VarType type);
/// @returns the resolved type of the ast::Expression `expr`
/// @param expr the expression
type::Type* TypeOf(ast::Expression* expr) const {
return program_->TypeOf(expr);
}
Namer namer_; Namer namer_;
ScopeStack<ast::Variable*> global_variables_; ScopeStack<ast::Variable*> global_variables_;
Symbol current_ep_sym_; Symbol current_ep_sym_;

View File

@ -59,6 +59,7 @@
#include "src/ast/variable.h" #include "src/ast/variable.h"
#include "src/ast/variable_decl_statement.h" #include "src/ast/variable_decl_statement.h"
#include "src/program.h" #include "src/program.h"
#include "src/semantic/expression.h"
#include "src/type/access_control_type.h" #include "src/type/access_control_type.h"
#include "src/type/alias_type.h" #include "src/type/alias_type.h"
#include "src/type/array_type.h" #include "src/type/array_type.h"
@ -405,7 +406,8 @@ bool Builder::GenerateAssignStatement(ast::AssignmentStatement* assign) {
} }
// If the thing we're assigning is a pointer then we must load it first. // If the thing we're assigning is a pointer then we must load it first.
rhs_id = GenerateLoadIfNeeded(assign->rhs()->result_type(), rhs_id); auto* type = TypeOf(assign->rhs());
rhs_id = GenerateLoadIfNeeded(type, rhs_id);
return GenerateStore(lhs_id, rhs_id); return GenerateStore(lhs_id, rhs_id);
} }
@ -639,7 +641,8 @@ bool Builder::GenerateFunctionVariable(ast::Variable* var) {
if (init_id == 0) { if (init_id == 0) {
return false; return false;
} }
init_id = GenerateLoadIfNeeded(var->constructor()->result_type(), init_id); auto* type = TypeOf(var->constructor());
init_id = GenerateLoadIfNeeded(type, init_id);
} }
if (var->is_const()) { if (var->is_const()) {
@ -843,7 +846,8 @@ bool Builder::GenerateArrayAccessor(ast::ArrayAccessorExpression* expr,
if (idx_id == 0) { if (idx_id == 0) {
return 0; return 0;
} }
idx_id = GenerateLoadIfNeeded(expr->idx_expr()->result_type(), idx_id); auto* type = TypeOf(expr->idx_expr());
idx_id = GenerateLoadIfNeeded(type, idx_id);
// If the source is a pointer we access chain into it. We also access chain // If the source is a pointer we access chain into it. We also access chain
// into an array of non-scalar types. // into an array of non-scalar types.
@ -851,11 +855,11 @@ bool Builder::GenerateArrayAccessor(ast::ArrayAccessorExpression* expr,
(info->source_type->Is<type::Array>() && (info->source_type->Is<type::Array>() &&
!info->source_type->As<type::Array>()->type()->is_scalar())) { !info->source_type->As<type::Array>()->type()->is_scalar())) {
info->access_chain_indices.push_back(idx_id); info->access_chain_indices.push_back(idx_id);
info->source_type = expr->result_type(); info->source_type = TypeOf(expr);
return true; return true;
} }
auto result_type_id = GenerateTypeIfNeeded(expr->result_type()); auto result_type_id = GenerateTypeIfNeeded(TypeOf(expr));
if (result_type_id == 0) { if (result_type_id == 0) {
return false; return false;
} }
@ -872,7 +876,7 @@ bool Builder::GenerateArrayAccessor(ast::ArrayAccessorExpression* expr,
} }
info->source_id = extract_id; info->source_id = extract_id;
info->source_type = expr->result_type(); info->source_type = TypeOf(expr);
return true; return true;
} }
@ -880,7 +884,8 @@ bool Builder::GenerateArrayAccessor(ast::ArrayAccessorExpression* expr,
bool Builder::GenerateMemberAccessor(ast::MemberAccessorExpression* expr, bool Builder::GenerateMemberAccessor(ast::MemberAccessorExpression* expr,
AccessorInfo* info) { AccessorInfo* info) {
auto* data_type = auto* data_type =
expr->structure()->result_type()->UnwrapPtrIfNeeded()->UnwrapIfNeeded(); TypeOf(expr->structure())->UnwrapPtrIfNeeded()->UnwrapIfNeeded();
auto* expr_type = TypeOf(expr);
// If the data_type is a structure we're accessing a member, if it's a // If the data_type is a structure we're accessing a member, if it's a
// vector we're accessing a swizzle. // vector we're accessing a swizzle.
@ -908,7 +913,7 @@ bool Builder::GenerateMemberAccessor(ast::MemberAccessorExpression* expr,
return 0; return 0;
} }
info->access_chain_indices.push_back(idx_id); info->access_chain_indices.push_back(idx_id);
info->source_type = expr->result_type(); info->source_type = expr_type;
return true; return true;
} }
@ -934,7 +939,7 @@ bool Builder::GenerateMemberAccessor(ast::MemberAccessorExpression* expr,
} }
info->access_chain_indices.push_back(idx_id); info->access_chain_indices.push_back(idx_id);
} else { } else {
auto result_type_id = GenerateTypeIfNeeded(expr->result_type()); auto result_type_id = GenerateTypeIfNeeded(expr_type);
if (result_type_id == 0) { if (result_type_id == 0) {
return 0; return 0;
} }
@ -949,7 +954,7 @@ bool Builder::GenerateMemberAccessor(ast::MemberAccessorExpression* expr,
} }
info->source_id = extract_id; info->source_id = extract_id;
info->source_type = expr->result_type(); info->source_type = expr_type;
} }
return true; return true;
} }
@ -977,12 +982,12 @@ bool Builder::GenerateMemberAccessor(ast::MemberAccessorExpression* expr,
return false; return false;
} }
info->source_id = GenerateLoadIfNeeded(expr->result_type(), extract_id); info->source_id = GenerateLoadIfNeeded(expr_type, extract_id);
info->source_type = expr->result_type()->UnwrapPtrIfNeeded(); info->source_type = expr_type->UnwrapPtrIfNeeded();
info->access_chain_indices.clear(); info->access_chain_indices.clear();
} }
auto result_type_id = GenerateTypeIfNeeded(expr->result_type()); auto result_type_id = GenerateTypeIfNeeded(expr_type);
if (result_type_id == 0) { if (result_type_id == 0) {
return false; return false;
} }
@ -1009,7 +1014,7 @@ bool Builder::GenerateMemberAccessor(ast::MemberAccessorExpression* expr,
return false; return false;
} }
info->source_id = result_id; info->source_id = result_id;
info->source_type = expr->result_type(); info->source_type = expr_type;
return true; return true;
} }
@ -1040,13 +1045,13 @@ uint32_t Builder::GenerateAccessorExpression(ast::Expression* expr) {
if (info.source_id == 0) { if (info.source_id == 0) {
return 0; return 0;
} }
info.source_type = source->result_type(); info.source_type = TypeOf(source);
// If our initial access is into an array of non-scalar types, and that array // If our initial access is into an array of non-scalar types, and that array
// is not a pointer, then we need to load that array into a variable in order // is not a pointer, then we need to load that array into a variable in order
// to access chain into the array. // to access chain into the array.
if (auto* array = accessors[0]->As<ast::ArrayAccessorExpression>()) { if (auto* array = accessors[0]->As<ast::ArrayAccessorExpression>()) {
auto* ary_res_type = array->array()->result_type(); auto* ary_res_type = TypeOf(array->array());
if (!ary_res_type->Is<type::Pointer>() && if (!ary_res_type->Is<type::Pointer>() &&
(ary_res_type->Is<type::Array>() && (ary_res_type->Is<type::Array>() &&
@ -1095,7 +1100,7 @@ uint32_t Builder::GenerateAccessorExpression(ast::Expression* expr) {
} }
if (!info.access_chain_indices.empty()) { if (!info.access_chain_indices.empty()) {
auto result_type_id = GenerateTypeIfNeeded(expr->result_type()); auto result_type_id = GenerateTypeIfNeeded(TypeOf(expr));
if (result_type_id == 0) { if (result_type_id == 0) {
return 0; return 0;
} }
@ -1153,16 +1158,16 @@ uint32_t Builder::GenerateUnaryOpExpression(ast::UnaryOpExpression* expr) {
if (val_id == 0) { if (val_id == 0) {
return 0; return 0;
} }
val_id = GenerateLoadIfNeeded(expr->expr()->result_type(), val_id); val_id = GenerateLoadIfNeeded(TypeOf(expr->expr()), val_id);
auto type_id = GenerateTypeIfNeeded(expr->result_type()); auto type_id = GenerateTypeIfNeeded(TypeOf(expr));
if (type_id == 0) { if (type_id == 0) {
return 0; return 0;
} }
spv::Op op = spv::Op::OpNop; spv::Op op = spv::Op::OpNop;
if (expr->op() == ast::UnaryOp::kNegation) { if (expr->op() == ast::UnaryOp::kNegation) {
if (expr->result_type()->is_float_scalar_or_vector()) { if (TypeOf(expr)->is_float_scalar_or_vector()) {
op = spv::Op::OpFNegate; op = spv::Op::OpFNegate;
} else { } else {
op = spv::Op::OpSNegate; op = spv::Op::OpSNegate;
@ -1260,7 +1265,7 @@ bool Builder::is_constructor_const(ast::Expression* expr, bool is_global_init) {
} else if (auto* str = subtype->As<type::Struct>()) { } else if (auto* str = subtype->As<type::Struct>()) {
subtype = str->impl()->members()[i]->type()->UnwrapAll(); subtype = str->impl()->members()[i]->type()->UnwrapAll();
} }
if (subtype != sc->result_type()->UnwrapAll()) { if (subtype != TypeOf(sc)->UnwrapAll()) {
return false; return false;
} }
} }
@ -1291,7 +1296,7 @@ uint32_t Builder::GenerateTypeConstructorExpression(
if (auto* res_vec = result_type->As<type::Vector>()) { if (auto* res_vec = result_type->As<type::Vector>()) {
if (res_vec->type()->is_scalar()) { if (res_vec->type()->is_scalar()) {
auto* value_type = values[0]->result_type()->UnwrapAll(); auto* value_type = TypeOf(values[0])->UnwrapAll();
if (auto* val_vec = value_type->As<type::Vector>()) { if (auto* val_vec = value_type->As<type::Vector>()) {
if (val_vec->type()->is_scalar()) { if (val_vec->type()->is_scalar()) {
can_cast_or_copy = res_vec->size() == val_vec->size(); can_cast_or_copy = res_vec->size() == val_vec->size();
@ -1324,13 +1329,13 @@ uint32_t Builder::GenerateTypeConstructorExpression(
nullptr, e->As<ast::ConstructorExpression>(), is_global_init); nullptr, e->As<ast::ConstructorExpression>(), is_global_init);
} else { } else {
id = GenerateExpression(e); id = GenerateExpression(e);
id = GenerateLoadIfNeeded(e->result_type(), id); id = GenerateLoadIfNeeded(TypeOf(e), id);
} }
if (id == 0) { if (id == 0) {
return 0; return 0;
} }
auto* value_type = e->result_type()->UnwrapPtrIfNeeded(); auto* value_type = TypeOf(e)->UnwrapPtrIfNeeded();
// If the result and value types are the same we can just use the object. // If the result and value types are the same we can just use the object.
// If the result is not a vector then we should have validated that the // If the result is not a vector then we should have validated that the
// value type is a correctly sized vector so we can just use it directly. // value type is a correctly sized vector so we can just use it directly.
@ -1445,9 +1450,9 @@ uint32_t Builder::GenerateCastOrCopyOrPassthrough(type::Type* to_type,
if (val_id == 0) { if (val_id == 0) {
return 0; return 0;
} }
val_id = GenerateLoadIfNeeded(from_expr->result_type(), val_id); val_id = GenerateLoadIfNeeded(TypeOf(from_expr), val_id);
auto* from_type = from_expr->result_type()->UnwrapPtrIfNeeded(); auto* from_type = TypeOf(from_expr)->UnwrapPtrIfNeeded();
spv::Op op = spv::Op::OpNop; spv::Op op = spv::Op::OpNop;
if ((from_type->Is<type::I32>() && to_type->Is<type::F32>()) || if ((from_type->Is<type::I32>() && to_type->Is<type::F32>()) ||
@ -1557,13 +1562,13 @@ uint32_t Builder::GenerateShortCircuitBinaryExpression(
if (lhs_id == 0) { if (lhs_id == 0) {
return false; return false;
} }
lhs_id = GenerateLoadIfNeeded(expr->lhs()->result_type(), lhs_id); lhs_id = GenerateLoadIfNeeded(TypeOf(expr->lhs()), lhs_id);
// Get the ID of the basic block where control flow will diverge. It's the // Get the ID of the basic block where control flow will diverge. It's the
// last basic block generated for the left-hand-side of the operator. // last basic block generated for the left-hand-side of the operator.
auto original_label_id = current_label_id_; auto original_label_id = current_label_id_;
auto type_id = GenerateTypeIfNeeded(expr->result_type()); auto type_id = GenerateTypeIfNeeded(TypeOf(expr));
if (type_id == 0) { if (type_id == 0) {
return 0; return 0;
} }
@ -1601,7 +1606,7 @@ uint32_t Builder::GenerateShortCircuitBinaryExpression(
if (rhs_id == 0) { if (rhs_id == 0) {
return 0; return 0;
} }
rhs_id = GenerateLoadIfNeeded(expr->rhs()->result_type(), rhs_id); rhs_id = GenerateLoadIfNeeded(TypeOf(expr->rhs()), rhs_id);
// Get the block ID of the last basic block generated for the right-hand-side // Get the block ID of the last basic block generated for the right-hand-side
// expression. That block will be an immediate predecessor to the merge block. // expression. That block will be an immediate predecessor to the merge block.
@ -1638,26 +1643,26 @@ uint32_t Builder::GenerateBinaryExpression(ast::BinaryExpression* expr) {
if (lhs_id == 0) { if (lhs_id == 0) {
return 0; return 0;
} }
lhs_id = GenerateLoadIfNeeded(expr->lhs()->result_type(), lhs_id); lhs_id = GenerateLoadIfNeeded(TypeOf(expr->lhs()), lhs_id);
auto rhs_id = GenerateExpression(expr->rhs()); auto rhs_id = GenerateExpression(expr->rhs());
if (rhs_id == 0) { if (rhs_id == 0) {
return 0; return 0;
} }
rhs_id = GenerateLoadIfNeeded(expr->rhs()->result_type(), rhs_id); rhs_id = GenerateLoadIfNeeded(TypeOf(expr->rhs()), rhs_id);
auto result = result_op(); auto result = result_op();
auto result_id = result.to_i(); auto result_id = result.to_i();
auto type_id = GenerateTypeIfNeeded(expr->result_type()); auto type_id = GenerateTypeIfNeeded(TypeOf(expr));
if (type_id == 0) { if (type_id == 0) {
return 0; return 0;
} }
// Handle int and float and the vectors of those types. Other types // Handle int and float and the vectors of those types. Other types
// should have been rejected by validation. // should have been rejected by validation.
auto* lhs_type = expr->lhs()->result_type()->UnwrapPtrIfNeeded(); auto* lhs_type = TypeOf(expr->lhs())->UnwrapPtrIfNeeded();
auto* rhs_type = expr->rhs()->result_type()->UnwrapPtrIfNeeded(); auto* rhs_type = TypeOf(expr->rhs())->UnwrapPtrIfNeeded();
bool lhs_is_float_or_vec = lhs_type->is_float_scalar_or_vector(); bool lhs_is_float_or_vec = lhs_type->is_float_scalar_or_vector();
bool lhs_is_unsigned = lhs_type->is_unsigned_scalar_or_vector(); bool lhs_is_unsigned = lhs_type->is_unsigned_scalar_or_vector();
@ -1806,7 +1811,7 @@ uint32_t Builder::GenerateCallExpression(ast::CallExpression* expr) {
return GenerateIntrinsic(ident, expr); return GenerateIntrinsic(ident, expr);
} }
auto type_id = GenerateTypeIfNeeded(expr->func()->result_type()); auto type_id = GenerateTypeIfNeeded(TypeOf(expr->func()));
if (type_id == 0) { if (type_id == 0) {
return 0; return 0;
} }
@ -1829,7 +1834,7 @@ uint32_t Builder::GenerateCallExpression(ast::CallExpression* expr) {
if (id == 0) { if (id == 0) {
return 0; return 0;
} }
id = GenerateLoadIfNeeded(param->result_type(), id); id = GenerateLoadIfNeeded(TypeOf(param), id);
ops.push_back(Operand::Int(id)); ops.push_back(Operand::Int(id));
} }
@ -1845,7 +1850,7 @@ uint32_t Builder::GenerateIntrinsic(ast::IdentifierExpression* ident,
auto result = result_op(); auto result = result_op();
auto result_id = result.to_i(); auto result_id = result.to_i();
auto result_type_id = GenerateTypeIfNeeded(call->result_type()); auto result_type_id = GenerateTypeIfNeeded(TypeOf(call));
if (result_type_id == 0) { if (result_type_id == 0) {
return 0; return 0;
} }
@ -1895,7 +1900,7 @@ uint32_t Builder::GenerateIntrinsic(ast::IdentifierExpression* ident,
} }
params.push_back(Operand::Int(struct_id)); params.push_back(Operand::Int(struct_id));
auto* type = accessor->structure()->result_type()->UnwrapAll(); auto* type = TypeOf(accessor->structure())->UnwrapAll();
if (!type->Is<type::Struct>()) { if (!type->Is<type::Struct>()) {
error_ = error_ =
"invalid type (" + type->type_name() + ") for runtime array length"; "invalid type (" + type->type_name() + ") for runtime array length";
@ -1948,8 +1953,7 @@ uint32_t Builder::GenerateIntrinsic(ast::IdentifierExpression* ident,
return 0; return 0;
} }
auto set_id = set_iter->second; auto set_id = set_iter->second;
auto inst_id = auto inst_id = intrinsic_to_glsl_method(TypeOf(ident), ident->intrinsic());
intrinsic_to_glsl_method(ident->result_type(), ident->intrinsic());
if (inst_id == 0) { if (inst_id == 0) {
error_ = "unknown method " + builder_.Symbols().NameFor(ident->symbol()); error_ = "unknown method " + builder_.Symbols().NameFor(ident->symbol());
return 0; return 0;
@ -1972,7 +1976,7 @@ uint32_t Builder::GenerateIntrinsic(ast::IdentifierExpression* ident,
if (val_id == 0) { if (val_id == 0) {
return false; return false;
} }
val_id = GenerateLoadIfNeeded(p->result_type(), val_id); val_id = GenerateLoadIfNeeded(TypeOf(p), val_id);
params.emplace_back(Operand::Int(val_id)); params.emplace_back(Operand::Int(val_id));
} }
@ -1995,10 +1999,8 @@ bool Builder::GenerateTextureIntrinsic(ast::IdentifierExpression* ident,
auto const kNotUsed = ast::intrinsic::TextureSignature::Parameters::kNotUsed; auto const kNotUsed = ast::intrinsic::TextureSignature::Parameters::kNotUsed;
assert(pidx.texture != kNotUsed); assert(pidx.texture != kNotUsed);
auto* texture_type = call->params()[pidx.texture] auto* texture_type =
->result_type() TypeOf(call->params()[pidx.texture])->UnwrapAll()->As<type::Texture>();
->UnwrapAll()
->As<type::Texture>();
auto op = spv::Op::OpNop; auto op = spv::Op::OpNop;
@ -2008,7 +2010,7 @@ bool Builder::GenerateTextureIntrinsic(ast::IdentifierExpression* ident,
if (val_id == 0) { if (val_id == 0) {
return Operand::Int(0); return Operand::Int(0);
} }
val_id = GenerateLoadIfNeeded(p->result_type(), val_id); val_id = GenerateLoadIfNeeded(TypeOf(p), val_id);
return Operand::Int(val_id); return Operand::Int(val_id);
}; };
@ -2076,7 +2078,7 @@ bool Builder::GenerateTextureIntrinsic(ast::IdentifierExpression* ident,
} else { } else {
// Assign post_emission to swizzle the result of the call to // Assign post_emission to swizzle the result of the call to
// OpImageQuerySize[Lod]. // OpImageQuerySize[Lod].
auto* element_type = ElementTypeOf(call->result_type()); auto* element_type = ElementTypeOf(TypeOf(call));
auto spirv_result = result_op(); auto spirv_result = result_op();
auto* spirv_result_type = auto* spirv_result_type =
builder_.create<type::Vector>(element_type, spirv_result_width); builder_.create<type::Vector>(element_type, spirv_result_width);
@ -2302,7 +2304,7 @@ bool Builder::GenerateTextureIntrinsic(ast::IdentifierExpression* ident,
} }
assert(pidx.level != kNotUsed); assert(pidx.level != kNotUsed);
auto level = Operand::Int(0); auto level = Operand::Int(0);
if (call->params()[pidx.level]->result_type()->Is<type::I32>()) { if (TypeOf(call->params()[pidx.level])->Is<type::I32>()) {
// Depth textures have i32 parameters for the level, but SPIR-V expects // Depth textures have i32 parameters for the level, but SPIR-V expects
// F32. Cast. // F32. Cast.
auto* f32 = builder_.create<type::F32>(); auto* f32 = builder_.create<type::F32>();
@ -2417,7 +2419,7 @@ uint32_t Builder::GenerateBitcastExpression(ast::BitcastExpression* expr) {
auto result = result_op(); auto result = result_op();
auto result_id = result.to_i(); auto result_id = result.to_i();
auto result_type_id = GenerateTypeIfNeeded(expr->result_type()); auto result_type_id = GenerateTypeIfNeeded(TypeOf(expr));
if (result_type_id == 0) { if (result_type_id == 0) {
return 0; return 0;
} }
@ -2426,11 +2428,11 @@ uint32_t Builder::GenerateBitcastExpression(ast::BitcastExpression* expr) {
if (val_id == 0) { if (val_id == 0) {
return 0; return 0;
} }
val_id = GenerateLoadIfNeeded(expr->expr()->result_type(), val_id); val_id = GenerateLoadIfNeeded(TypeOf(expr->expr()), val_id);
// Bitcast does not allow same types, just emit a CopyObject // Bitcast does not allow same types, just emit a CopyObject
auto* to_type = expr->result_type()->UnwrapPtrIfNeeded(); auto* to_type = TypeOf(expr)->UnwrapPtrIfNeeded();
auto* from_type = expr->expr()->result_type()->UnwrapPtrIfNeeded(); auto* from_type = TypeOf(expr->expr())->UnwrapPtrIfNeeded();
if (to_type->type_name() == from_type->type_name()) { if (to_type->type_name() == from_type->type_name()) {
if (!push_function_inst( if (!push_function_inst(
spv::Op::OpCopyObject, spv::Op::OpCopyObject,
@ -2457,7 +2459,7 @@ bool Builder::GenerateConditionalBlock(
if (cond_id == 0) { if (cond_id == 0) {
return false; return false;
} }
cond_id = GenerateLoadIfNeeded(cond->result_type(), cond_id); cond_id = GenerateLoadIfNeeded(TypeOf(cond), cond_id);
auto merge_block = result_op(); auto merge_block = result_op();
auto merge_block_id = merge_block.to_i(); auto merge_block_id = merge_block.to_i();
@ -2545,7 +2547,7 @@ bool Builder::GenerateSwitchStatement(ast::SwitchStatement* stmt) {
if (cond_id == 0) { if (cond_id == 0) {
return false; return false;
} }
cond_id = GenerateLoadIfNeeded(stmt->condition()->result_type(), cond_id); cond_id = GenerateLoadIfNeeded(TypeOf(stmt->condition()), cond_id);
auto default_block = result_op(); auto default_block = result_op();
auto default_block_id = default_block.to_i(); auto default_block_id = default_block.to_i();
@ -2641,7 +2643,7 @@ bool Builder::GenerateReturnStatement(ast::ReturnStatement* stmt) {
if (val_id == 0) { if (val_id == 0) {
return false; return false;
} }
val_id = GenerateLoadIfNeeded(stmt->value()->result_type(), val_id); val_id = GenerateLoadIfNeeded(TypeOf(stmt->value()), val_id);
if (!push_function_inst(spv::Op::OpReturnValue, {Operand::Int(val_id)})) { if (!push_function_inst(spv::Op::OpReturnValue, {Operand::Int(val_id)})) {
return false; return false;
} }

View File

@ -488,6 +488,12 @@ class Builder {
/// automatically. /// automatically.
Operand result_op(); Operand result_op();
/// @returns the resolved type of the ast::Expression `expr`
/// @param expr the expression
type::Type* TypeOf(ast::Expression* expr) const {
return builder_.TypeOf(expr);
}
ProgramBuilder builder_; ProgramBuilder builder_;
std::string error_; std::string error_;
uint32_t next_id_ = 1; uint32_t next_id_ = 1;