ast: Migrate to using ast::Type

Remove all sem::Type references from the AST.
ConstructedTypes are now all AST types.

The parsers will still create semantic types, but these are now disjoint
and ignored.
The parsers will be updated with future changes to stop creating these
semantic types.

Resolver creates semantic types from the AST types. Most downstream
logic continues to use the semantic types, however transforms will now
need to rebuild AST type information instead of reassigning semantic
information, as semantic nodes are fully rebuilt by the Resolver.

Bug: tint:724
Change-Id: I4ce03a075f13c77648cda5c3691bae202752ecc5
Reviewed-on: https://dawn-review.googlesource.com/c/tint/+/49747
Commit-Queue: Ben Clayton <bclayton@chromium.org>
Commit-Queue: Ben Clayton <bclayton@google.com>
Kokoro: Kokoro <noreply+kokoro@google.com>
Reviewed-by: Antonio Maiorano <amaiorano@google.com>
Reviewed-by: James Price <jrprice@google.com>
This commit is contained in:
Ben Clayton 2021-05-05 09:09:41 +00:00 committed by Commit Bot service account
parent 781de097eb
commit 02ebf0dcae
72 changed files with 1267 additions and 1091 deletions

View File

@ -23,10 +23,10 @@ namespace ast {
BitcastExpression::BitcastExpression(ProgramID program_id,
const Source& source,
typ::Type type,
ast::Type* type,
Expression* expr)
: Base(program_id, source), type_(type), expr_(expr) {
TINT_ASSERT(type_.ast || type_.sem);
TINT_ASSERT(type_);
TINT_ASSERT(expr_);
TINT_ASSERT_PROGRAM_IDS_EQUAL_IF_VALID(expr, program_id);
}
@ -37,7 +37,7 @@ BitcastExpression::~BitcastExpression() = default;
BitcastExpression* BitcastExpression::Clone(CloneContext* ctx) const {
// Clone arguments outside of create() call to have deterministic ordering
auto src = ctx->Clone(source());
auto ty = ctx->Clone(type());
auto* ty = ctx->Clone(type());
auto* e = ctx->Clone(expr_);
return ctx->dst->create<BitcastExpression>(src, ty, e);
}
@ -46,9 +46,8 @@ void BitcastExpression::to_str(const sem::Info& sem,
std::ostream& out,
size_t indent) const {
make_indent(out, indent);
out << "Bitcast[" << result_type_str(sem) << "]<"
<< (type_.ast ? type_.ast->type_name() : type_.sem->type_name()) << ">{"
<< std::endl;
out << "Bitcast[" << result_type_str(sem) << "]<" << type_->type_name()
<< ">{" << std::endl;
expr_->to_str(sem, out, indent + 2);
make_indent(out, indent);
out << "}" << std::endl;

View File

@ -30,14 +30,14 @@ class BitcastExpression : public Castable<BitcastExpression, Expression> {
/// @param expr the expr
BitcastExpression(ProgramID program_id,
const Source& source,
typ::Type type,
ast::Type* type,
Expression* expr);
/// Move constructor
BitcastExpression(BitcastExpression&&);
~BitcastExpression() override;
/// @returns the left side expression
typ::Type type() const { return type_; }
ast::Type* type() const { return type_; }
/// @returns the expression
Expression* expr() const { return expr_; }
@ -58,7 +58,7 @@ class BitcastExpression : public Castable<BitcastExpression, Expression> {
private:
BitcastExpression(const BitcastExpression&) = delete;
typ::Type const type_;
ast::Type* const type_;
Expression* const expr_;
};

View File

@ -27,7 +27,7 @@ TEST_F(BitcastExpressionTest, Create) {
auto* expr = Expr("expr");
auto* exp = create<BitcastExpression>(ty.f32(), expr);
ASSERT_EQ(exp->type(), ty.f32());
EXPECT_TRUE(exp->type()->Is<ast::F32>());
ASSERT_EQ(exp->expr(), expr);
}

View File

@ -27,7 +27,7 @@ Function::Function(ProgramID program_id,
const Source& source,
Symbol symbol,
VariableList params,
typ::Type return_type,
ast::Type* return_type,
BlockStatement* body,
DecorationList decorations,
DecorationList return_type_decorations)
@ -45,7 +45,7 @@ Function::Function(ProgramID program_id,
TINT_ASSERT_PROGRAM_IDS_EQUAL_IF_VALID(param, program_id);
}
TINT_ASSERT(symbol_.IsValid());
TINT_ASSERT(return_type_.ast || return_type_.sem);
TINT_ASSERT(return_type_);
for (auto* deco : decorations_) {
TINT_ASSERT_PROGRAM_IDS_EQUAL_IF_VALID(deco, program_id);
}
@ -81,7 +81,7 @@ Function* Function::Clone(CloneContext* ctx) const {
auto src = ctx->Clone(source());
auto sym = ctx->Clone(symbol());
auto p = ctx->Clone(params_);
auto ret = ctx->Clone(return_type_);
auto* ret = ctx->Clone(return_type_);
auto* b = ctx->Clone(body_);
auto decos = ctx->Clone(decorations_);
auto ret_decos = ctx->Clone(return_type_decorations_);
@ -92,9 +92,7 @@ void Function::to_str(const sem::Info& sem,
std::ostream& out,
size_t indent) const {
make_indent(out, indent);
out << "Function " << symbol_.to_str() << " -> "
<< (return_type_.ast ? return_type_.ast->type_name()
: return_type_.sem->type_name())
out << "Function " << symbol_.to_str() << " -> " << return_type_->type_name()
<< std::endl;
for (auto* deco : decorations()) {
@ -134,7 +132,7 @@ std::string Function::type_name() const {
for (auto* param : params_) {
// No need for the sem::Variable here, functions params must have a
// type
out << param->declared_type()->type_name();
out << param->type()->type_name();
}
return out.str();

View File

@ -49,7 +49,7 @@ class Function : public Castable<Function, Node> {
const Source& source,
Symbol symbol,
VariableList params,
typ::Type return_type,
ast::Type* return_type,
BlockStatement* body,
DecorationList decorations,
DecorationList return_type_decorations);
@ -77,7 +77,7 @@ class Function : public Castable<Function, Node> {
bool IsEntryPoint() const { return pipeline_stage() != PipelineStage::kNone; }
/// @returns the function return type.
typ::Type return_type() const { return return_type_; }
ast::Type* return_type() const { return return_type_; }
/// @returns the decorations attached to the function return type.
const DecorationList& return_type_decorations() const {
@ -115,7 +115,7 @@ class Function : public Castable<Function, Node> {
Symbol const symbol_;
VariableList const params_;
typ::Type const return_type_;
ast::Type* const return_type_;
BlockStatement* const body_;
DecorationList const decorations_;
DecorationList const return_type_decorations_;

View File

@ -32,7 +32,7 @@ TEST_F(FunctionTest, Creation) {
auto* f = Func("func", params, ty.void_(), StatementList{}, DecorationList{});
EXPECT_EQ(f->symbol(), Symbols().Get("func"));
ASSERT_EQ(f->params().size(), 1u);
EXPECT_EQ(f->return_type(), ty.void_());
EXPECT_TRUE(f->return_type()->Is<ast::Void>());
EXPECT_EQ(f->params()[0], var);
}

View File

@ -132,7 +132,7 @@ std::ostream& operator<<(std::ostream& out, const TextureOverloadCase& data) {
return out;
}
typ::Type TextureOverloadCase::resultVectorComponentType(
ast::Type* TextureOverloadCase::resultVectorComponentType(
ProgramBuilder* b) const {
switch (texture_data_type) {
case ast::intrinsic::test::TextureDataType::kF32:
@ -149,7 +149,7 @@ typ::Type TextureOverloadCase::resultVectorComponentType(
ast::Variable* TextureOverloadCase::buildTextureVariable(
ProgramBuilder* b) const {
auto datatype = resultVectorComponentType(b);
auto* datatype = resultVectorComponentType(b);
DecorationList decos = {
b->create<ast::GroupDecoration>(0),
@ -166,8 +166,7 @@ ast::Variable* TextureOverloadCase::buildTextureVariable(
ast::StorageClass::kUniformConstant, nullptr, decos);
case ast::intrinsic::test::TextureKind::kMultisampled:
return b->Global(
"texture",
return b->Global("texture",
b->ty.multisampled_texture(texture_dimension, datatype),
ast::StorageClass::kUniformConstant, nullptr, decos);

View File

@ -209,7 +209,7 @@ struct TextureOverloadCase {
/// @param builder the AST builder used for the test
/// @returns the vector component type of the texture function return value
typ::Type resultVectorComponentType(ProgramBuilder* builder) const;
ast::Type* resultVectorComponentType(ProgramBuilder* builder) const;
/// @param builder the AST builder used for the test
/// @returns a variable holding the test texture, automatically registered as
/// a global variable.

View File

@ -29,14 +29,14 @@ Module::Module(ProgramID program_id, const Source& source)
Module::Module(ProgramID program_id,
const Source& source,
std::vector<Cloneable*> global_decls)
std::vector<ast::Node*> global_decls)
: Base(program_id, source), global_declarations_(std::move(global_decls)) {
for (auto* decl : global_declarations_) {
if (decl == nullptr) {
continue;
}
if (auto* ty = decl->As<sem::Type>()) {
if (auto* ty = decl->As<ast::NamedType>()) {
constructed_types_.push_back(ty);
} else if (auto* func = decl->As<Function>()) {
functions_.push_back(func);
@ -52,16 +52,34 @@ Module::Module(ProgramID program_id,
Module::~Module() = default;
const ast::NamedType* Module::LookupType(Symbol name) const {
for (auto ct : ConstructedTypes()) {
if (auto* ty = ct.ast->As<ast::NamedType>()) {
for (auto* ty : ConstructedTypes()) {
if (ty->name() == name) {
return ty;
}
}
}
return nullptr;
}
void Module::AddGlobalVariable(ast::Variable* var) {
TINT_ASSERT(var);
TINT_ASSERT_PROGRAM_IDS_EQUAL_IF_VALID(var, program_id());
global_variables_.push_back(var);
global_declarations_.push_back(var);
}
void Module::AddConstructedType(ast::NamedType* type) {
TINT_ASSERT(type);
constructed_types_.push_back(type);
global_declarations_.push_back(type);
}
void Module::AddFunction(ast::Function* func) {
TINT_ASSERT(func);
TINT_ASSERT_PROGRAM_IDS_EQUAL_IF_VALID(func, program_id());
functions_.push_back(func);
global_declarations_.push_back(func);
}
Module* Module::Clone(CloneContext* ctx) const {
auto* out = ctx->dst->create<Module>();
out->Copy(ctx, this);
@ -74,7 +92,7 @@ void Module::Copy(CloneContext* ctx, const Module* src) {
TINT_ICE(ctx->dst->Diagnostics()) << "src global declaration was nullptr";
continue;
}
if (auto* ty = decl->As<sem::Type>()) {
if (auto* ty = decl->As<ast::NamedType>()) {
AddConstructedType(ty);
} else if (auto* func = decl->As<Function>()) {
AddFunction(func);
@ -92,16 +110,16 @@ void Module::to_str(const sem::Info& sem,
make_indent(out, indent);
out << "Module{" << std::endl;
indent += 2;
for (auto const ty : constructed_types_) {
for (auto* ty : constructed_types_) {
make_indent(out, indent);
if (auto* alias = ty->As<sem::Alias>()) {
if (auto* alias = ty->As<ast::Alias>()) {
out << alias->symbol().to_str() << " -> " << alias->type()->type_name()
<< std::endl;
if (auto* str = alias->type()->As<sem::StructType>()) {
str->impl()->to_str(sem, out, indent);
if (auto* str = alias->type()->As<ast::Struct>()) {
str->to_str(sem, out, indent);
}
} else if (auto* str = ty->As<sem::StructType>()) {
str->impl()->to_str(sem, out, indent);
} else if (auto* str = ty->As<ast::Struct>()) {
str->to_str(sem, out, indent);
}
}
for (auto* var : global_variables_) {

View File

@ -42,28 +42,23 @@ class Module : public Castable<Module, Node> {
/// the order they were declared in the source program
Module(ProgramID program_id,
const Source& source,
std::vector<Cloneable*> global_decls);
std::vector<ast::Node*> global_decls);
/// Destructor
~Module() override;
/// @returns the ordered global declarations for the translation unit
const std::vector<Cloneable*>& GlobalDeclarations() const {
const std::vector<ast::Node*>& GlobalDeclarations() const {
return global_declarations_;
}
/// Add a global variable to the Builder
/// @param var the variable to add
void AddGlobalVariable(ast::Variable* var) {
TINT_ASSERT(var);
TINT_ASSERT_PROGRAM_IDS_EQUAL_IF_VALID(var, program_id());
global_variables_.push_back(var);
global_declarations_.push_back(var);
}
void AddGlobalVariable(ast::Variable* var);
/// @returns true if the module has the global declaration `decl`
/// @param decl the declaration to check
bool HasGlobalDeclaration(const Cloneable* decl) const {
bool HasGlobalDeclaration(ast::Node* decl) const {
for (auto* d : global_declarations_) {
if (d == decl) {
return true;
@ -79,31 +74,21 @@ class Module : public Castable<Module, Node> {
VariableList& GlobalVariables() { return global_variables_; }
/// Adds a constructed type to the Builder.
/// The type must be an alias or a struct.
/// @param type the constructed type to add
void AddConstructedType(typ::Type type) {
TINT_ASSERT(type);
constructed_types_.push_back(type);
global_declarations_.push_back(const_cast<sem::Type*>(type.sem));
}
void AddConstructedType(ast::NamedType* type);
/// @returns the NamedType registered as a ConstructedType()
/// @param name the name of the type to search for
const ast::NamedType* LookupType(Symbol name) const;
/// @returns the constructed types in the translation unit
const std::vector<typ::Type>& ConstructedTypes() const {
const std::vector<ast::NamedType*>& ConstructedTypes() const {
return constructed_types_;
}
/// Add a function to the Builder
/// @param func the function to add
void AddFunction(ast::Function* func) {
TINT_ASSERT(func);
TINT_ASSERT_PROGRAM_IDS_EQUAL_IF_VALID(func, program_id());
functions_.push_back(func);
global_declarations_.push_back(func);
}
void AddFunction(ast::Function* func);
/// @returns the functions declared in the translation unit
const FunctionList& Functions() const { return functions_; }
@ -132,8 +117,8 @@ class Module : public Castable<Module, Node> {
std::string to_str(const sem::Info& sem) const;
private:
std::vector<Cloneable*> global_declarations_;
std::vector<typ::Type> constructed_types_;
std::vector<ast::Node*> global_declarations_;
std::vector<ast::NamedType*> constructed_types_;
FunctionList functions_;
VariableList global_variables_;
};

View File

@ -24,13 +24,13 @@ namespace ast {
StructMember::StructMember(ProgramID program_id,
const Source& source,
const Symbol& sym,
typ::Type type,
ast::Type* type,
DecorationList decorations)
: Base(program_id, source),
symbol_(sym),
type_(type),
decorations_(std::move(decorations)) {
TINT_ASSERT(type.ast || type.sem);
TINT_ASSERT(type);
TINT_ASSERT(symbol_.IsValid());
TINT_ASSERT_PROGRAM_IDS_EQUAL_IF_VALID(symbol_, program_id);
for (auto* deco : decorations_) {
@ -59,7 +59,7 @@ StructMember* StructMember::Clone(CloneContext* ctx) const {
// Clone arguments outside of create() call to have deterministic ordering
auto src = ctx->Clone(source());
auto sym = ctx->Clone(symbol_);
auto ty = ctx->Clone(type_);
auto* ty = ctx->Clone(type_);
auto decos = ctx->Clone(decorations_);
return ctx->dst->create<StructMember>(src, sym, ty, decos);
}
@ -76,9 +76,7 @@ void StructMember::to_str(const sem::Info& sem,
out << "]] ";
}
out << symbol_.to_str() << ": "
<< (type_.ast ? type_.ast->type_name() : type_.sem->type_name()) << "}"
<< std::endl;
out << symbol_.to_str() << ": " << type_->type_name() << "}" << std::endl;
}
} // namespace ast

View File

@ -36,7 +36,7 @@ class StructMember : public Castable<StructMember, Node> {
StructMember(ProgramID program_id,
const Source& source,
const Symbol& sym,
typ::Type type,
ast::Type* type,
DecorationList decorations);
/// Move constructor
StructMember(StructMember&&);
@ -47,7 +47,7 @@ class StructMember : public Castable<StructMember, Node> {
const Symbol& symbol() const { return symbol_; }
/// @returns the type
typ::Type type() const { return type_; }
ast::Type* type() const { return type_; }
/// @returns the decorations
const DecorationList& decorations() const { return decorations_; }
@ -75,7 +75,7 @@ class StructMember : public Castable<StructMember, Node> {
StructMember(const StructMember&) = delete;
Symbol const symbol_;
typ::Type const type_;
ast::Type* const type_;
DecorationList const decorations_;
};

View File

@ -24,7 +24,7 @@ using StructMemberTest = TestHelper;
TEST_F(StructMemberTest, Creation) {
auto* st = Member("a", ty.i32(), {MemberSize(4)});
EXPECT_EQ(st->symbol(), Symbol(1, ID()));
EXPECT_EQ(st->type(), ty.i32());
EXPECT_TRUE(st->type()->Is<ast::I32>());
EXPECT_EQ(st->decorations().size(), 1u);
EXPECT_TRUE(st->decorations()[0]->Is<StructMemberSizeDecoration>());
EXPECT_EQ(st->source().range.begin.line, 0u);
@ -38,7 +38,7 @@ TEST_F(StructMemberTest, CreationWithSource) {
Source{Source::Range{Source::Location{27, 4}, Source::Location{27, 8}}},
"a", ty.i32());
EXPECT_EQ(st->symbol(), Symbol(1, ID()));
EXPECT_EQ(st->type(), ty.i32());
EXPECT_TRUE(st->type()->Is<ast::I32>());
EXPECT_EQ(st->decorations().size(), 0u);
EXPECT_EQ(st->source().range.begin.line, 27u);
EXPECT_EQ(st->source().range.begin.column, 4u);

View File

@ -23,10 +23,10 @@ namespace ast {
TypeConstructorExpression::TypeConstructorExpression(ProgramID program_id,
const Source& source,
typ::Type type,
ast::Type* type,
ExpressionList values)
: Base(program_id, source), type_(type), values_(std::move(values)) {
TINT_ASSERT(type_.ast || type_.sem);
TINT_ASSERT(type_);
for (auto* val : values_) {
TINT_ASSERT(val);
TINT_ASSERT_PROGRAM_IDS_EQUAL_IF_VALID(val, program_id);
@ -42,7 +42,7 @@ TypeConstructorExpression* TypeConstructorExpression::Clone(
CloneContext* ctx) const {
// Clone arguments outside of create() call to have deterministic ordering
auto src = ctx->Clone(source());
auto ty = ctx->Clone(type());
auto* ty = ctx->Clone(type());
auto vals = ctx->Clone(values());
return ctx->dst->create<TypeConstructorExpression>(src, ty, vals);
}
@ -53,8 +53,7 @@ void TypeConstructorExpression::to_str(const sem::Info& sem,
make_indent(out, indent);
out << "TypeConstructor[" << result_type_str(sem) << "]{" << std::endl;
make_indent(out, indent + 2);
out << (type_.ast ? type_.ast->type_name() : type_.sem->type_name())
<< std::endl;
out << type_->type_name() << std::endl;
for (auto* val : values_) {
val->to_str(sem, out, indent + 2);

View File

@ -33,14 +33,14 @@ class TypeConstructorExpression
/// @param values the constructor values
TypeConstructorExpression(ProgramID program_id,
const Source& source,
typ::Type type,
ast::Type* type,
ExpressionList values);
/// Move constructor
TypeConstructorExpression(TypeConstructorExpression&&);
~TypeConstructorExpression() override;
/// @returns the type
typ::Type type() const { return type_; }
ast::Type* type() const { return type_; }
/// @returns the values
const ExpressionList& values() const { return values_; }
@ -62,7 +62,7 @@ class TypeConstructorExpression
private:
TypeConstructorExpression(const TypeConstructorExpression&) = delete;
typ::Type const type_;
ast::Type* const type_;
ExpressionList const values_;
};

View File

@ -26,7 +26,7 @@ TEST_F(TypeConstructorExpressionTest, Creation) {
expr.push_back(Expr("expr"));
auto* t = create<TypeConstructorExpression>(ty.f32(), expr);
EXPECT_EQ(t->type(), ty.f32());
EXPECT_TRUE(t->type()->Is<ast::F32>());
ASSERT_EQ(t->values().size(), 1u);
EXPECT_EQ(t->values()[0], expr[0]);
}

View File

@ -27,7 +27,7 @@ Variable::Variable(ProgramID program_id,
const Source& source,
const Symbol& sym,
StorageClass declared_storage_class,
const typ::Type type,
ast::Type* type,
bool is_const,
Expression* constructor,
DecorationList decorations)
@ -41,7 +41,7 @@ Variable::Variable(ProgramID program_id,
TINT_ASSERT(symbol_.IsValid());
TINT_ASSERT_PROGRAM_IDS_EQUAL_IF_VALID(symbol_, program_id);
// no type means we must have a constructor to infer it
TINT_ASSERT(type_.ast || type_.sem || constructor);
TINT_ASSERT(type_ || constructor);
TINT_ASSERT_PROGRAM_IDS_EQUAL_IF_VALID(constructor, program_id);
}
@ -73,7 +73,7 @@ uint32_t Variable::constant_id() const {
Variable* Variable::Clone(CloneContext* ctx) const {
auto src = ctx->Clone(source());
auto sym = ctx->Clone(symbol());
auto ty = ctx->Clone(type());
auto* ty = ctx->Clone(type());
auto* ctor = ctx->Clone(constructor());
auto decos = ctx->Clone(decorations());
return ctx->dst->create<Variable>(src, sym, declared_storage_class(), ty,
@ -90,8 +90,7 @@ void Variable::info_to_str(const sem::Info& sem,
out << (var_sem ? var_sem->StorageClass() : declared_storage_class())
<< std::endl;
make_indent(out, indent);
out << (type_.sem ? type_.sem->type_name() : type_.ast->type_name())
<< std::endl;
out << type_->type_name() << std::endl;
}
void Variable::constructor_to_str(const sem::Info& sem,

View File

@ -109,7 +109,7 @@ class Variable : public Castable<Variable, Node> {
const Source& source,
const Symbol& sym,
StorageClass declared_storage_class,
typ::Type type,
ast::Type* type,
bool is_const,
Expression* constructor,
DecorationList decorations);
@ -121,12 +121,8 @@ class Variable : public Castable<Variable, Node> {
/// @returns the variable symbol
const Symbol& symbol() const { return symbol_; }
/// @returns the declared type
// TODO(crbug.com/tint/697): Remove and use type() instead
sem::Type* declared_type() const { return const_cast<sem::Type*>(type_.sem); }
/// @returns the variable type
typ::Type type() const { return type_; }
ast::Type* type() const { return type_; }
/// @returns the declared storage class
StorageClass declared_storage_class() const {
@ -185,7 +181,7 @@ class Variable : public Castable<Variable, Node> {
Symbol const symbol_;
// The value type if a const or formal paramter, and the store type if a var
typ::Type const type_;
ast::Type* const type_;
bool const is_const_;
Expression* const constructor_;
DecorationList const decorations_;

View File

@ -27,7 +27,7 @@ TEST_F(VariableTest, Creation) {
EXPECT_EQ(v->symbol(), Symbol(1, ID()));
EXPECT_EQ(v->declared_storage_class(), StorageClass::kFunction);
EXPECT_EQ(v->declared_type(), ty.i32());
EXPECT_TRUE(v->type()->Is<ast::I32>());
EXPECT_EQ(v->source().range.begin.line, 0u);
EXPECT_EQ(v->source().range.begin.column, 0u);
EXPECT_EQ(v->source().range.end.line, 0u);
@ -41,7 +41,7 @@ TEST_F(VariableTest, CreationWithSource) {
EXPECT_EQ(v->symbol(), Symbol(1, ID()));
EXPECT_EQ(v->declared_storage_class(), StorageClass::kPrivate);
EXPECT_EQ(v->declared_type(), ty.f32());
EXPECT_TRUE(v->type()->Is<ast::F32>());
EXPECT_EQ(v->source().range.begin.line, 27u);
EXPECT_EQ(v->source().range.begin.column, 4u);
EXPECT_EQ(v->source().range.end.line, 27u);
@ -55,7 +55,7 @@ TEST_F(VariableTest, CreationEmpty) {
EXPECT_EQ(v->symbol(), Symbol(1, ID()));
EXPECT_EQ(v->declared_storage_class(), StorageClass::kWorkgroup);
EXPECT_EQ(v->declared_type(), ty.i32());
EXPECT_TRUE(v->type()->Is<ast::I32>());
EXPECT_EQ(v->source().range.begin.line, 27u);
EXPECT_EQ(v->source().range.begin.column, 4u);
EXPECT_EQ(v->source().range.end.line, 27u);

View File

@ -190,12 +190,12 @@ class CloneContext {
return CheckedCast<T>(c);
}
/// Clones the type pair
/// Clones the AST node of the type pair
/// @param tp the type pair to clone
/// @return the cloned type pair
/// @return the cloned AST node wrapped in a type pair
template <typename AST, typename SEM>
typ::TypePair<AST, SEM> Clone(const typ::TypePair<AST, SEM>& tp) {
return Clone(const_cast<sem::Type*>(tp.sem));
return Clone(const_cast<ast::Type*>(tp.ast));
}
/// Clones the Source `s` into #dst

View File

@ -194,6 +194,8 @@ std::vector<EntryPoint> Inspector::GetEntryPoints() {
continue;
}
auto* sem = program_->Sem().Get(func);
EntryPoint entry_point;
entry_point.name = program_->Symbols().NameFor(func->symbol());
entry_point.remapped_name = program_->Symbols().NameFor(func->symbol());
@ -201,20 +203,21 @@ std::vector<EntryPoint> Inspector::GetEntryPoints() {
std::tie(entry_point.workgroup_size_x, entry_point.workgroup_size_y,
entry_point.workgroup_size_z) = func->workgroup_size();
for (auto* param : func->params()) {
AddEntryPointInOutVariables(program_->Symbols().NameFor(param->symbol()),
param->declared_type(), param->decorations(),
for (auto* param : sem->Parameters()) {
AddEntryPointInOutVariables(
program_->Symbols().NameFor(param->Declaration()->symbol()),
param->Type(), param->Declaration()->decorations(),
entry_point.input_variables);
}
if (!func->return_type()->Is<sem::Void>()) {
AddEntryPointInOutVariables("<retval>", func->return_type(),
if (!sem->ReturnType()->Is<sem::Void>()) {
AddEntryPointInOutVariables("<retval>", sem->ReturnType(),
func->return_type_decorations(),
entry_point.output_variables);
}
// TODO(crbug.com/tint/697): Remove this.
for (auto* var : program_->Sem().Get(func)->ReferencedModuleVariables()) {
for (auto* var : sem->ReferencedModuleVariables()) {
auto* decl = var->Declaration();
auto name = program_->Symbols().NameFor(decl->symbol());
@ -553,10 +556,12 @@ void Inspector::AddEntryPointInOutVariables(
if (auto* struct_ty = unwrapped_type->As<sem::StructType>()) {
// Recurse into members.
for (auto* member : struct_ty->impl()->members()) {
auto* sem = program_->Sem().Get(struct_ty);
for (auto* member : sem->Members()) {
AddEntryPointInOutVariables(
name + "." + program_->Symbols().NameFor(member->symbol()),
member->type(), member->decorations(), variables);
name + "." +
program_->Symbols().NameFor(member->Declaration()->symbol()),
member->Type(), member->Declaration()->decorations(), variables);
}
return;
}

View File

@ -144,7 +144,7 @@ class InspectorHelper : public ProgramBuilder {
/// @param val value to initialize the variable with, if NULL no initializer
/// will be added.
template <class T>
void AddConstantID(std::string name, uint32_t id, typ::Type type, T* val) {
void AddConstantID(std::string name, uint32_t id, ast::Type* type, T* val) {
ast::Expression* constructor = nullptr;
if (val) {
constructor = Expr(*val);
@ -172,9 +172,8 @@ class InspectorHelper : public ProgramBuilder {
/// @param idx index of member
/// @param type type of member
/// @returns a string for the member
std::string StructMemberName(size_t idx, typ::Type type) {
return std::to_string(idx) +
(type.sem ? type.sem->type_name() : type.ast->type_name());
std::string StructMemberName(size_t idx, ast::Type* type) {
return std::to_string(idx) + type->type_name();
}
/// Generates a struct type
@ -182,11 +181,11 @@ class InspectorHelper : public ProgramBuilder {
/// @param member_types a vector of member types
/// @param is_block whether or not to decorate as a Block
/// @returns a struct type
typ::Struct MakeStructType(const std::string& name,
std::vector<typ::Type> member_types,
ast::Struct* MakeStructType(const std::string& name,
std::vector<ast::Type*> member_types,
bool is_block) {
ast::StructMemberList members;
for (auto type : member_types) {
for (auto* type : member_types) {
members.push_back(Member(StructMemberName(members.size(), type), type));
}
@ -202,37 +201,37 @@ class InspectorHelper : public ProgramBuilder {
/// @param name name for the type
/// @param member_types a vector of member types
/// @returns a struct type that has the layout for an uniform buffer.
typ::Struct MakeUniformBufferType(const std::string& name,
std::vector<typ::Type> member_types) {
ast::Struct* MakeUniformBufferType(const std::string& name,
std::vector<ast::Type*> member_types) {
return MakeStructType(name, member_types, true);
}
/// Generates types appropriate for using in a storage buffer
/// @param name name for the type
/// @param member_types a vector of member types
/// @returns a tuple {struct type, access control type}, where the struct has
/// the layout for a storage buffer, and the control type wraps the
/// struct.
std::tuple<typ::Struct, typ::AccessControl> MakeStorageBufferTypes(
/// @returns a function that returns an ast::AccessControl to the created
/// structure.
std::function<ast::AccessControl*()> MakeStorageBufferTypes(
const std::string& name,
std::vector<typ::Type> member_types) {
auto struct_type = MakeStructType(name, member_types, true);
auto access_type = ty.access(ast::AccessControl::kReadWrite, struct_type);
return {struct_type, std::move(access_type)};
std::vector<ast::Type*> member_types) {
MakeStructType(name, member_types, true);
return [this, name] {
return ty.access(ast::AccessControl::kReadWrite, ty.type_name(name));
};
}
/// Generates types appropriate for using in a read-only storage buffer
/// @param name name for the type
/// @param member_types a vector of member types
/// @returns a tuple {struct type, access control type}, where the struct has
/// the layout for a read-only storage buffer, and the control type
/// wraps the struct.
std::tuple<typ::Struct, typ::AccessControl> MakeReadOnlyStorageBufferTypes(
/// @returns a function that returns an ast::AccessControl to the created
/// structure.
std::function<ast::AccessControl*()> MakeReadOnlyStorageBufferTypes(
const std::string& name,
std::vector<typ::Type> member_types) {
auto struct_type = MakeStructType(name, member_types, true);
auto access_type = ty.access(ast::AccessControl::kReadOnly, struct_type);
return {struct_type, std::move(access_type)};
std::vector<ast::Type*> member_types) {
MakeStructType(name, member_types, true);
return [this, name] {
return ty.access(ast::AccessControl::kReadOnly, ty.type_name(name));
};
}
/// Adds a binding variable with a struct type to the program
@ -242,7 +241,7 @@ class InspectorHelper : public ProgramBuilder {
/// @param group the binding and group to use for the uniform buffer
/// @param binding the binding number to use for the uniform buffer
void AddBinding(const std::string& name,
typ::Type type,
ast::Type* type,
ast::StorageClass storage_class,
uint32_t group,
uint32_t binding) {
@ -259,7 +258,7 @@ class InspectorHelper : public ProgramBuilder {
/// @param group the binding/group/ to use for the uniform buffer
/// @param binding the binding number to use for the uniform buffer
void AddUniformBuffer(const std::string& name,
typ::Type type,
ast::Type* type,
uint32_t group,
uint32_t binding) {
AddBinding(name, type, ast::StorageClass::kUniform, group, binding);
@ -271,7 +270,7 @@ class InspectorHelper : public ProgramBuilder {
/// @param group the binding/group to use for the storage buffer
/// @param binding the binding number to use for the storage buffer
void AddStorageBuffer(const std::string& name,
typ::Type type,
ast::Type* type,
uint32_t group,
uint32_t binding) {
AddBinding(name, type, ast::StorageClass::kStorage, group, binding);
@ -284,11 +283,11 @@ class InspectorHelper : public ProgramBuilder {
void MakeStructVariableReferenceBodyFunction(
std::string func_name,
std::string struct_name,
std::vector<std::tuple<size_t, typ::Type>> members) {
std::vector<std::tuple<size_t, ast::Type*>> members) {
ast::StatementList stmts;
for (auto member : members) {
size_t member_idx;
typ::Type member_type;
ast::Type* member_type;
std::tie(member_idx, member_type) = member;
std::string member_name = StructMemberName(member_idx, member_type);
@ -298,7 +297,7 @@ class InspectorHelper : public ProgramBuilder {
for (auto member : members) {
size_t member_idx;
typ::Type member_type;
ast::Type* member_type;
std::tie(member_idx, member_type) = member;
std::string member_name = StructMemberName(member_idx, member_type);
@ -337,7 +336,7 @@ class InspectorHelper : public ProgramBuilder {
/// @param type the data type of the sampled texture
/// @returns the generated SampleTextureType
typ::SampledTexture MakeSampledTextureType(ast::TextureDimension dim,
typ::Type type) {
ast::Type* type) {
return ty.sampled_texture(dim, type);
}
@ -354,7 +353,7 @@ class InspectorHelper : public ProgramBuilder {
/// @returns the generated SampleTextureType
typ::MultisampledTexture MakeMultisampledTextureType(
ast::TextureDimension dim,
typ::Type type) {
ast::Type* type) {
return ty.multisampled_texture(dim, type);
}
@ -364,7 +363,7 @@ class InspectorHelper : public ProgramBuilder {
/// @param group the binding/group to use for the sampled texture
/// @param binding the binding number to use for the sampled texture
void AddSampledTexture(const std::string& name,
typ::Type type,
ast::Type* type,
uint32_t group,
uint32_t binding) {
AddBinding(name, type, ast::StorageClass::kUniformConstant, group, binding);
@ -376,13 +375,13 @@ class InspectorHelper : public ProgramBuilder {
/// @param group the binding/group to use for the multi-sampled texture
/// @param binding the binding number to use for the multi-sampled texture
void AddMultisampledTexture(const std::string& name,
typ::Type type,
ast::Type* type,
uint32_t group,
uint32_t binding) {
AddBinding(name, type, ast::StorageClass::kUniformConstant, group, binding);
}
void AddGlobalVariable(const std::string& name, typ::Type type) {
void AddGlobalVariable(const std::string& name, ast::Type* type) {
Global(name, type, ast::StorageClass::kUniformConstant);
}
@ -392,7 +391,7 @@ class InspectorHelper : public ProgramBuilder {
/// @param group the binding/group to use for the depth texture
/// @param binding the binding number to use for the depth texture
void AddDepthTexture(const std::string& name,
typ::Type type,
ast::Type* type,
uint32_t group,
uint32_t binding) {
AddBinding(name, type, ast::StorageClass::kUniformConstant, group, binding);
@ -411,7 +410,7 @@ class InspectorHelper : public ProgramBuilder {
const std::string& texture_name,
const std::string& sampler_name,
const std::string& coords_name,
typ::Type base_type,
ast::Type* base_type,
ast::DecorationList decorations) {
std::string result_name = "sampler_result";
@ -442,7 +441,7 @@ class InspectorHelper : public ProgramBuilder {
const std::string& sampler_name,
const std::string& coords_name,
const std::string& array_index,
typ::Type base_type,
ast::Type* base_type,
ast::DecorationList decorations) {
std::string result_name = "sampler_result";
@ -475,7 +474,7 @@ class InspectorHelper : public ProgramBuilder {
const std::string& sampler_name,
const std::string& coords_name,
const std::string& depth_name,
typ::Type base_type,
ast::Type* base_type,
ast::DecorationList decorations) {
std::string result_name = "sampler_result";
@ -494,7 +493,7 @@ class InspectorHelper : public ProgramBuilder {
/// Gets an appropriate type for the data in a given texture type.
/// @param sampled_kind type of in the texture
/// @returns a pointer to a type appropriate for the coord param
typ::Type GetBaseType(ResourceBinding::SampledKind sampled_kind) {
ast::Type* GetBaseType(ResourceBinding::SampledKind sampled_kind) {
switch (sampled_kind) {
case ResourceBinding::SampledKind::kFloat:
return ty.f32();
@ -512,17 +511,17 @@ class InspectorHelper : public ProgramBuilder {
/// @param dim dimensionality of the texture being sampled
/// @param scalar the scalar type
/// @returns a pointer to a type appropriate for the coord param
typ::Type GetCoordsType(ast::TextureDimension dim, typ::Type scalar) {
ast::Type* GetCoordsType(ast::TextureDimension dim, ast::Type* scalar) {
switch (dim) {
case ast::TextureDimension::k1d:
return scalar;
case ast::TextureDimension::k2d:
case ast::TextureDimension::k2dArray:
return create<sem::Vector>(scalar, 2);
return create<ast::Vector>(scalar, 2);
case ast::TextureDimension::k3d:
case ast::TextureDimension::kCube:
case ast::TextureDimension::kCubeArray:
return create<sem::Vector>(scalar, 3);
return create<ast::Vector>(scalar, 3);
default:
[=]() { FAIL() << "Unsupported texture dimension: " << dim; }();
}
@ -604,8 +603,10 @@ class InspectorHelper : public ProgramBuilder {
return *inspector_;
}
typ::Sampler sampler_type() { return ty.sampler(ast::SamplerKind::kSampler); }
typ::Sampler comparison_sampler_type() {
ast::Sampler* sampler_type() {
return ty.sampler(ast::SamplerKind::kSampler);
}
ast::Sampler* comparison_sampler_type() {
return ty.sampler(ast::SamplerKind::kComparisonSampler);
}
@ -835,23 +836,23 @@ TEST_F(InspectorGetEntryPointTest, NoInOutVariables) {
TEST_P(InspectorGetEntryPointTestWithComponentTypeParam, InOutVariables) {
ComponentType inspector_type = GetParam();
typ::Type tint_type = nullptr;
std::function<typ::Type()> tint_type;
switch (inspector_type) {
case ComponentType::kFloat:
tint_type = ty.f32();
tint_type = [this]() -> typ::Type { return ty.f32(); };
break;
case ComponentType::kSInt:
tint_type = ty.i32();
tint_type = [this]() -> typ::Type { return ty.i32(); };
break;
case ComponentType::kUInt:
tint_type = ty.u32();
tint_type = [this]() -> typ::Type { return ty.u32(); };
break;
case ComponentType::kUnknown:
return;
}
auto* in_var = Param("in_var", tint_type, {Location(0u)});
Func("foo", {in_var}, tint_type, {Return("in_var")},
auto* in_var = Param("in_var", tint_type(), {Location(0u)});
Func("foo", {in_var}, tint_type(), {Return("in_var")},
{Stage(ast::PipelineStage::kFragment)}, {Location(0u)});
Inspector& inspector = Build();
@ -1600,18 +1601,12 @@ TEST_F(InspectorGetResourceBindingsTest, Simple) {
AddUniformBuffer("ub_var", ub_struct_type, 0, 0);
MakeStructVariableReferenceBodyFunction("ub_func", "ub_var", {{0, ty.i32()}});
typ::Struct sb_struct_type;
typ::AccessControl sb_control_type;
std::tie(sb_struct_type, sb_control_type) =
MakeStorageBufferTypes("sb_type", {ty.i32()});
AddStorageBuffer("sb_var", sb_control_type, 1, 0);
auto sb = MakeStorageBufferTypes("sb_type", {ty.i32()});
AddStorageBuffer("sb_var", sb(), 1, 0);
MakeStructVariableReferenceBodyFunction("sb_func", "sb_var", {{0, ty.i32()}});
typ::Struct rosb_struct_type;
typ::AccessControl rosb_control_type;
std::tie(rosb_struct_type, rosb_control_type) =
MakeReadOnlyStorageBufferTypes("rosb_type", {ty.i32()});
AddStorageBuffer("rosb_var", rosb_control_type, 1, 1);
auto ro_sb = MakeReadOnlyStorageBufferTypes("rosb_type", {ty.i32()});
AddStorageBuffer("rosb_var", ro_sb(), 1, 1);
MakeStructVariableReferenceBodyFunction("rosb_func", "rosb_var",
{{0, ty.i32()}});
@ -1896,11 +1891,8 @@ TEST_F(InspectorGetUniformBufferResourceBindingsTest, ContainingArray) {
}
TEST_F(InspectorGetStorageBufferResourceBindingsTest, Simple) {
typ::Struct foo_struct_type;
typ::AccessControl foo_control_type;
std::tie(foo_struct_type, foo_control_type) =
MakeStorageBufferTypes("foo_type", {ty.i32()});
AddStorageBuffer("foo_sb", foo_control_type, 0, 0);
auto foo_struct_type = MakeStorageBufferTypes("foo_type", {ty.i32()});
AddStorageBuffer("foo_sb", foo_struct_type(), 0, 0);
MakeStructVariableReferenceBodyFunction("sb_func", "foo_sb", {{0, ty.i32()}});
@ -1924,11 +1916,12 @@ TEST_F(InspectorGetStorageBufferResourceBindingsTest, Simple) {
}
TEST_F(InspectorGetStorageBufferResourceBindingsTest, MultipleMembers) {
typ::Struct foo_struct_type;
typ::AccessControl foo_control_type;
std::tie(foo_struct_type, foo_control_type) =
MakeStorageBufferTypes("foo_type", {ty.i32(), ty.u32(), ty.f32()});
AddStorageBuffer("foo_sb", foo_control_type, 0, 0);
auto foo_struct_type = MakeStorageBufferTypes("foo_type", {
ty.i32(),
ty.u32(),
ty.f32(),
});
AddStorageBuffer("foo_sb", foo_struct_type(), 0, 0);
MakeStructVariableReferenceBodyFunction(
"sb_func", "foo_sb", {{0, ty.i32()}, {1, ty.u32()}, {2, ty.f32()}});
@ -1953,13 +1946,14 @@ TEST_F(InspectorGetStorageBufferResourceBindingsTest, MultipleMembers) {
}
TEST_F(InspectorGetStorageBufferResourceBindingsTest, MultipleStorageBuffers) {
typ::Struct sb_struct_type;
typ::AccessControl sb_control_type;
std::tie(sb_struct_type, sb_control_type) =
MakeStorageBufferTypes("sb_type", {ty.i32(), ty.u32(), ty.f32()});
AddStorageBuffer("sb_foo", sb_control_type, 0, 0);
AddStorageBuffer("sb_bar", sb_control_type, 0, 1);
AddStorageBuffer("sb_baz", sb_control_type, 2, 0);
auto sb_struct_type = MakeStorageBufferTypes("sb_type", {
ty.i32(),
ty.u32(),
ty.f32(),
});
AddStorageBuffer("sb_foo", sb_struct_type(), 0, 0);
AddStorageBuffer("sb_bar", sb_struct_type(), 0, 1);
AddStorageBuffer("sb_baz", sb_struct_type(), 2, 0);
auto AddReferenceFunc = [this](const std::string& func_name,
const std::string& var_name) {
@ -2014,11 +2008,9 @@ TEST_F(InspectorGetStorageBufferResourceBindingsTest, MultipleStorageBuffers) {
}
TEST_F(InspectorGetStorageBufferResourceBindingsTest, ContainingArray) {
typ::Struct foo_struct_type;
typ::AccessControl foo_control_type;
std::tie(foo_struct_type, foo_control_type) =
auto foo_struct_type =
MakeStorageBufferTypes("foo_type", {ty.i32(), ty.array<u32, 4>()});
AddStorageBuffer("foo_sb", foo_control_type, 0, 0);
AddStorageBuffer("foo_sb", foo_struct_type(), 0, 0);
MakeStructVariableReferenceBodyFunction("sb_func", "foo_sb", {{0, ty.i32()}});
@ -2042,11 +2034,11 @@ TEST_F(InspectorGetStorageBufferResourceBindingsTest, ContainingArray) {
}
TEST_F(InspectorGetStorageBufferResourceBindingsTest, ContainingRuntimeArray) {
typ::Struct foo_struct_type;
typ::AccessControl foo_control_type;
std::tie(foo_struct_type, foo_control_type) =
MakeStorageBufferTypes("foo_type", {ty.i32(), ty.array<u32>()});
AddStorageBuffer("foo_sb", foo_control_type, 0, 0);
auto foo_struct_type = MakeStorageBufferTypes("foo_type", {
ty.i32(),
ty.array<u32>(),
});
AddStorageBuffer("foo_sb", foo_struct_type(), 0, 0);
MakeStructVariableReferenceBodyFunction("sb_func", "foo_sb", {{0, ty.i32()}});
@ -2070,11 +2062,8 @@ TEST_F(InspectorGetStorageBufferResourceBindingsTest, ContainingRuntimeArray) {
}
TEST_F(InspectorGetStorageBufferResourceBindingsTest, ContainingPadding) {
typ::Struct foo_struct_type;
typ::AccessControl foo_control_type;
std::tie(foo_struct_type, foo_control_type) =
MakeStorageBufferTypes("foo_type", {ty.vec3<f32>()});
AddStorageBuffer("foo_sb", foo_control_type, 0, 0);
auto foo_struct_type = MakeStorageBufferTypes("foo_type", {ty.vec3<f32>()});
AddStorageBuffer("foo_sb", foo_struct_type(), 0, 0);
MakeStructVariableReferenceBodyFunction("sb_func", "foo_sb",
{{0, ty.vec3<f32>()}});
@ -2099,11 +2088,8 @@ TEST_F(InspectorGetStorageBufferResourceBindingsTest, ContainingPadding) {
}
TEST_F(InspectorGetStorageBufferResourceBindingsTest, SkipReadOnly) {
typ::Struct foo_struct_type;
typ::AccessControl foo_control_type;
std::tie(foo_struct_type, foo_control_type) =
MakeReadOnlyStorageBufferTypes("foo_type", {ty.i32()});
AddStorageBuffer("foo_sb", foo_control_type, 0, 0);
auto foo_struct_type = MakeReadOnlyStorageBufferTypes("foo_type", {ty.i32()});
AddStorageBuffer("foo_sb", foo_struct_type(), 0, 0);
MakeStructVariableReferenceBodyFunction("sb_func", "foo_sb", {{0, ty.i32()}});
@ -2120,11 +2106,8 @@ TEST_F(InspectorGetStorageBufferResourceBindingsTest, SkipReadOnly) {
}
TEST_F(InspectorGetReadOnlyStorageBufferResourceBindingsTest, Simple) {
typ::Struct foo_struct_type;
typ::AccessControl foo_control_type;
std::tie(foo_struct_type, foo_control_type) =
MakeReadOnlyStorageBufferTypes("foo_type", {ty.i32()});
AddStorageBuffer("foo_sb", foo_control_type, 0, 0);
auto foo_struct_type = MakeReadOnlyStorageBufferTypes("foo_type", {ty.i32()});
AddStorageBuffer("foo_sb", foo_struct_type(), 0, 0);
MakeStructVariableReferenceBodyFunction("sb_func", "foo_sb", {{0, ty.i32()}});
@ -2149,13 +2132,14 @@ TEST_F(InspectorGetReadOnlyStorageBufferResourceBindingsTest, Simple) {
TEST_F(InspectorGetReadOnlyStorageBufferResourceBindingsTest,
MultipleStorageBuffers) {
typ::Struct sb_struct_type;
typ::AccessControl sb_control_type;
std::tie(sb_struct_type, sb_control_type) =
MakeReadOnlyStorageBufferTypes("sb_type", {ty.i32(), ty.u32(), ty.f32()});
AddStorageBuffer("sb_foo", sb_control_type, 0, 0);
AddStorageBuffer("sb_bar", sb_control_type, 0, 1);
AddStorageBuffer("sb_baz", sb_control_type, 2, 0);
auto sb_struct_type = MakeReadOnlyStorageBufferTypes("sb_type", {
ty.i32(),
ty.u32(),
ty.f32(),
});
AddStorageBuffer("sb_foo", sb_struct_type(), 0, 0);
AddStorageBuffer("sb_bar", sb_struct_type(), 0, 1);
AddStorageBuffer("sb_baz", sb_struct_type(), 2, 0);
auto AddReferenceFunc = [this](const std::string& func_name,
const std::string& var_name) {
@ -2210,11 +2194,12 @@ TEST_F(InspectorGetReadOnlyStorageBufferResourceBindingsTest,
}
TEST_F(InspectorGetReadOnlyStorageBufferResourceBindingsTest, ContainingArray) {
typ::Struct foo_struct_type;
typ::AccessControl foo_control_type;
std::tie(foo_struct_type, foo_control_type) = MakeReadOnlyStorageBufferTypes(
"foo_type", {ty.i32(), ty.array<u32, 4>()});
AddStorageBuffer("foo_sb", foo_control_type, 0, 0);
auto foo_struct_type =
MakeReadOnlyStorageBufferTypes("foo_type", {
ty.i32(),
ty.array<u32, 4>(),
});
AddStorageBuffer("foo_sb", foo_struct_type(), 0, 0);
MakeStructVariableReferenceBodyFunction("sb_func", "foo_sb", {{0, ty.i32()}});
@ -2239,11 +2224,12 @@ TEST_F(InspectorGetReadOnlyStorageBufferResourceBindingsTest, ContainingArray) {
TEST_F(InspectorGetReadOnlyStorageBufferResourceBindingsTest,
ContainingRuntimeArray) {
typ::Struct foo_struct_type;
typ::AccessControl foo_control_type;
std::tie(foo_struct_type, foo_control_type) =
MakeReadOnlyStorageBufferTypes("foo_type", {ty.i32(), ty.array<u32>()});
AddStorageBuffer("foo_sb", foo_control_type, 0, 0);
auto foo_struct_type =
MakeReadOnlyStorageBufferTypes("foo_type", {
ty.i32(),
ty.array<u32>(),
});
AddStorageBuffer("foo_sb", foo_struct_type(), 0, 0);
MakeStructVariableReferenceBodyFunction("sb_func", "foo_sb", {{0, ty.i32()}});
@ -2267,11 +2253,8 @@ TEST_F(InspectorGetReadOnlyStorageBufferResourceBindingsTest,
}
TEST_F(InspectorGetReadOnlyStorageBufferResourceBindingsTest, SkipNonReadOnly) {
typ::Struct foo_struct_type;
typ::AccessControl foo_control_type;
std::tie(foo_struct_type, foo_control_type) =
MakeStorageBufferTypes("foo_type", {ty.i32()});
AddStorageBuffer("foo_sb", foo_control_type, 0, 0);
auto foo_struct_type = MakeStorageBufferTypes("foo_type", {ty.i32()});
AddStorageBuffer("foo_sb", foo_struct_type(), 0, 0);
MakeStructVariableReferenceBodyFunction("sb_func", "foo_sb", {{0, ty.i32()}});
@ -2514,7 +2497,7 @@ TEST_P(InspectorGetSampledTextureResourceBindingsTestWithParam, textureSample) {
GetParam().type_dim, GetBaseType(GetParam().sampled_kind));
AddSampledTexture("foo_texture", sampled_texture_type, 0, 0);
AddSampler("foo_sampler", 0, 1);
auto coord_type = GetCoordsType(GetParam().type_dim, ty.f32());
auto* coord_type = GetCoordsType(GetParam().type_dim, ty.f32());
AddGlobalVariable("foo_coords", coord_type);
MakeSamplerReferenceBodyFunction("ep", "foo_texture", "foo_sampler",
@ -2572,7 +2555,7 @@ TEST_P(InspectorGetSampledArrayTextureResourceBindingsTestWithParam,
GetParam().type_dim, GetBaseType(GetParam().sampled_kind));
AddSampledTexture("foo_texture", sampled_texture_type, 0, 0);
AddSampler("foo_sampler", 0, 1);
auto coord_type = GetCoordsType(GetParam().type_dim, ty.f32());
auto* coord_type = GetCoordsType(GetParam().type_dim, ty.f32());
AddGlobalVariable("foo_coords", coord_type);
AddGlobalVariable("foo_array_index", ty.i32());
@ -2615,7 +2598,7 @@ TEST_P(InspectorGetMultisampledTextureResourceBindingsTestWithParam,
auto multisampled_texture_type = MakeMultisampledTextureType(
GetParam().type_dim, GetBaseType(GetParam().sampled_kind));
AddMultisampledTexture("foo_texture", multisampled_texture_type, 0, 0);
auto coord_type = GetCoordsType(GetParam().type_dim, ty.i32());
auto* coord_type = GetCoordsType(GetParam().type_dim, ty.i32());
AddGlobalVariable("foo_coords", coord_type);
AddGlobalVariable("foo_sample_index", ty.i32());
@ -2633,9 +2616,9 @@ TEST_P(InspectorGetMultisampledTextureResourceBindingsTestWithParam,
auto result = inspector.GetMultisampledTextureResourceBindings("ep");
ASSERT_FALSE(inspector.has_error()) << inspector.error();
ASSERT_EQ(1u, result.size());
EXPECT_EQ(ResourceBinding::ResourceType::kMultisampledTexture,
result[0].resource_type);
ASSERT_EQ(1u, result.size());
EXPECT_EQ(0u, result[0].bind_group);
EXPECT_EQ(0u, result[0].binding);
EXPECT_EQ(GetParam().inspector_dim, result[0].dim);
@ -2685,7 +2668,7 @@ TEST_P(InspectorGetMultisampledArrayTextureResourceBindingsTestWithParam,
GetParam().type_dim, GetBaseType(GetParam().sampled_kind));
AddMultisampledTexture("foo_texture", multisampled_texture_type, 0, 0);
AddSampler("foo_sampler", 0, 1);
auto coord_type = GetCoordsType(GetParam().type_dim, ty.f32());
auto* coord_type = GetCoordsType(GetParam().type_dim, ty.f32());
AddGlobalVariable("foo_coords", coord_type);
AddGlobalVariable("foo_array_index", ty.i32());

View File

@ -16,7 +16,6 @@
#include "src/ast/assignment_statement.h"
#include "src/ast/call_statement.h"
#include "src/ast/type_name.h"
#include "src/ast/variable_decl_statement.h"
#include "src/debug.h"
#include "src/demangler.h"
@ -96,51 +95,62 @@ const sem::Type* ProgramBuilder::TypeOf(const ast::Type* type) const {
}
ast::ConstructorExpression* ProgramBuilder::ConstructValueFilledWith(
typ::Type type,
const ast::Type* type,
int elem_value) {
auto* unwrapped_type = type->UnwrapAliasIfNeeded();
if (unwrapped_type->Is<sem::Bool>()) {
CloneContext ctx(this);
if (type->Is<ast::Bool>()) {
return create<ast::ScalarConstructorExpression>(
create<ast::BoolLiteral>(elem_value == 0 ? false : true));
}
if (unwrapped_type->Is<sem::I32>()) {
if (type->Is<ast::I32>()) {
return create<ast::ScalarConstructorExpression>(
create<ast::SintLiteral>(static_cast<i32>(elem_value)));
}
if (unwrapped_type->Is<sem::U32>()) {
if (type->Is<ast::U32>()) {
return create<ast::ScalarConstructorExpression>(
create<ast::UintLiteral>(static_cast<u32>(elem_value)));
}
if (unwrapped_type->Is<sem::F32>()) {
if (type->Is<ast::F32>()) {
return create<ast::ScalarConstructorExpression>(
create<ast::FloatLiteral>(static_cast<f32>(elem_value)));
}
if (auto* v = unwrapped_type->As<sem::Vector>()) {
if (auto* v = type->As<ast::Vector>()) {
ast::ExpressionList el(v->size());
for (size_t i = 0; i < el.size(); i++) {
el[i] = ConstructValueFilledWith(v->type(), elem_value);
el[i] = ConstructValueFilledWith(ctx.Clone(v->type()), elem_value);
}
return create<ast::TypeConstructorExpression>(type, std::move(el));
return create<ast::TypeConstructorExpression>(const_cast<ast::Type*>(type),
std::move(el));
}
if (auto* m = unwrapped_type->As<sem::Matrix>()) {
auto* col_vec_type = create<sem::Vector>(m->type(), m->rows());
ast::ExpressionList el(col_vec_type->size());
if (auto* m = type->As<ast::Matrix>()) {
ast::ExpressionList el(m->columns());
for (size_t i = 0; i < el.size(); i++) {
auto* col_vec_type = create<ast::Vector>(ctx.Clone(m->type()), m->rows());
el[i] = ConstructValueFilledWith(col_vec_type, elem_value);
}
return create<ast::TypeConstructorExpression>(type, std::move(el));
return create<ast::TypeConstructorExpression>(const_cast<ast::Type*>(type),
std::move(el));
}
TINT_ASSERT(false);
if (auto* tn = type->As<ast::TypeName>()) {
if (auto* lookup = AST().LookupType(tn->name())) {
if (auto* alias = lookup->As<ast::Alias>()) {
return ConstructValueFilledWith(ctx.Clone(alias->type()), elem_value);
}
}
TINT_ICE(diagnostics_) << "unable to find NamedType '"
<< Symbols().NameFor(tn->name()) << "'";
return nullptr;
}
TINT_ICE(diagnostics_) << "unhandled type: " << type->TypeInfo().name;
return nullptr;
}
typ::Type ProgramBuilder::TypesBuilder::MaybeCreateTypename(
typ::Type type) const {
if (auto* alias = As<ast::Alias>(type.ast)) {
return {builder->create<ast::TypeName>(alias->symbol()), type.sem};
}
if (auto* str = As<ast::Struct>(type.ast)) {
return {builder->create<ast::TypeName>(str->name()), type.sem};
if (auto* nt = As<ast::NamedType>(type.ast)) {
return {type_name(nt->name()), type.sem};
}
return type;
}

View File

@ -1060,7 +1060,7 @@ class ProgramBuilder {
/// @return an `ast::TypeConstructorExpression` of `type` constructed with the
/// values `args`.
template <typename... ARGS>
ast::TypeConstructorExpression* Construct(typ::Type type, ARGS&&... args) {
ast::TypeConstructorExpression* Construct(ast::Type* type, ARGS&&... args) {
type = ty.MaybeCreateTypename(type);
return create<ast::TypeConstructorExpression>(
type, ExprList(std::forward<ARGS>(args)...));
@ -1074,7 +1074,7 @@ class ProgramBuilder {
/// @param elem_value the initial or element value (for vec and mat) to
/// construct with
/// @return the constructor expression
ast::ConstructorExpression* ConstructValueFilledWith(typ::Type type,
ast::ConstructorExpression* ConstructValueFilledWith(const ast::Type* type,
int elem_value = 0);
/// @param args the arguments for the vector constructor
@ -1187,7 +1187,7 @@ class ProgramBuilder {
/// @return an `ast::TypeConstructorExpression` of an array with element type
/// `subtype`, constructed with the values `args`.
template <typename... ARGS>
ast::TypeConstructorExpression* array(typ::Type subtype,
ast::TypeConstructorExpression* array(ast::Type* subtype,
uint32_t n,
ARGS&&... args) {
return Construct(ty.array(subtype, n), std::forward<ARGS>(args)...);
@ -1201,7 +1201,7 @@ class ProgramBuilder {
/// @returns a `ast::Variable` with the given name, storage and type
template <typename NAME>
ast::Variable* Var(NAME&& name,
typ::Type type,
ast::Type* type,
ast::StorageClass storage,
ast::Expression* constructor = nullptr,
ast::DecorationList decorations = {}) {
@ -1220,7 +1220,7 @@ class ProgramBuilder {
template <typename NAME>
ast::Variable* Var(const Source& source,
NAME&& name,
typ::Type type,
ast::Type* type,
ast::StorageClass storage,
ast::Expression* constructor = nullptr,
ast::DecorationList decorations = {}) {
@ -1236,7 +1236,7 @@ class ProgramBuilder {
/// @returns a constant `ast::Variable` with the given name and type
template <typename NAME>
ast::Variable* Const(NAME&& name,
typ::Type type,
ast::Type* type,
ast::Expression* constructor = nullptr,
ast::DecorationList decorations = {}) {
type = ty.MaybeCreateTypename(type);
@ -1254,7 +1254,7 @@ class ProgramBuilder {
template <typename NAME>
ast::Variable* Const(const Source& source,
NAME&& name,
typ::Type type,
ast::Type* type,
ast::Expression* constructor = nullptr,
ast::DecorationList decorations = {}) {
type = ty.MaybeCreateTypename(type);
@ -1269,7 +1269,7 @@ class ProgramBuilder {
/// @returns a constant `ast::Variable` with the given name and type
template <typename NAME>
ast::Variable* Param(NAME&& name,
typ::Type type,
ast::Type* type,
ast::DecorationList decorations = {}) {
type = ty.MaybeCreateTypename(type);
return create<ast::Variable>(Sym(std::forward<NAME>(name)),
@ -1285,7 +1285,7 @@ class ProgramBuilder {
template <typename NAME>
ast::Variable* Param(const Source& source,
NAME&& name,
typ::Type type,
ast::Type* type,
ast::DecorationList decorations = {}) {
type = ty.MaybeCreateTypename(type);
return create<ast::Variable>(source, Sym(std::forward<NAME>(name)),
@ -1302,7 +1302,7 @@ class ProgramBuilder {
/// global variable with the ast::Module.
template <typename NAME>
ast::Variable* Global(NAME&& name,
typ::Type type,
ast::Type* type,
ast::StorageClass storage,
ast::Expression* constructor = nullptr,
ast::DecorationList decorations = {}) {
@ -1323,7 +1323,7 @@ class ProgramBuilder {
template <typename NAME>
ast::Variable* Global(const Source& source,
NAME&& name,
typ::Type type,
ast::Type* type,
ast::StorageClass storage,
ast::Expression* constructor = nullptr,
ast::DecorationList decorations = {}) {
@ -1476,7 +1476,7 @@ class ProgramBuilder {
ast::Function* Func(const Source& source,
NAME&& name,
ast::VariableList params,
typ::Type type,
ast::Type* type,
ast::StatementList body,
ast::DecorationList decorations = {},
ast::DecorationList return_type_decorations = {}) {
@ -1501,7 +1501,7 @@ class ProgramBuilder {
template <typename NAME>
ast::Function* Func(NAME&& name,
ast::VariableList params,
typ::Type type,
ast::Type* type,
ast::StatementList body,
ast::DecorationList decorations = {},
ast::DecorationList return_type_decorations = {}) {
@ -1588,7 +1588,7 @@ class ProgramBuilder {
template <typename NAME>
ast::StructMember* Member(const Source& source,
NAME&& name,
typ::Type type,
ast::Type* type,
ast::DecorationList decorations = {}) {
type = ty.MaybeCreateTypename(type);
return create<ast::StructMember>(source, Sym(std::forward<NAME>(name)),
@ -1602,7 +1602,7 @@ class ProgramBuilder {
/// @returns the struct member pointer
template <typename NAME>
ast::StructMember* Member(NAME&& name,
typ::Type type,
ast::Type* type,
ast::DecorationList decorations = {}) {
type = ty.MaybeCreateTypename(type);
return create<ast::StructMember>(source_, Sym(std::forward<NAME>(name)),
@ -1615,7 +1615,7 @@ class ProgramBuilder {
/// @param type the struct member type
/// @returns the struct member pointer
template <typename NAME>
ast::StructMember* Member(uint32_t offset, NAME&& name, typ::Type type) {
ast::StructMember* Member(uint32_t offset, NAME&& name, ast::Type* type) {
type = ty.MaybeCreateTypename(type);
return create<ast::StructMember>(
source_, Sym(std::forward<NAME>(name)), type,

View File

@ -218,7 +218,7 @@ TEST_F(SpvParserTest_Composite_Construct, Struct) {
VariableConst{
x_1
none
__struct_S
__type_name_S
{
TypeConstructor[not set]{
__type_name_S
@ -835,14 +835,14 @@ TEST_F(SpvParserTest_CompositeInsert, Struct) {
Variable{
x_35
function
__struct_S
__type_name_S
}
}
VariableDeclStatement{
VariableConst{
x_1
none
__struct_S
__type_name_S
{
Identifier[not set]{x_35}
}
@ -852,7 +852,7 @@ VariableDeclStatement{
Variable{
x_2_1
function
__struct_S
__type_name_S
{
Identifier[not set]{x_1}
}
@ -869,7 +869,7 @@ VariableDeclStatement{
VariableConst{
x_2
none
__struct_S
__type_name_S
{
Identifier[not set]{x_2_1}
}
@ -909,21 +909,21 @@ TEST_F(SpvParserTest_CompositeInsert, Struct_DifferOnlyInMemberName) {
Variable{
x_40
function
__struct_S_2
__type_name_S_2
}
}
VariableDeclStatement{
Variable{
x_41
function
__struct_S_2
__type_name_S_2
}
}
VariableDeclStatement{
VariableConst{
x_1
none
__struct_S_1
__type_name_S_1
{
Identifier[not set]{x_40}
}
@ -933,7 +933,7 @@ VariableDeclStatement{
Variable{
x_2_1
function
__struct_S_1
__type_name_S_1
{
Identifier[not set]{x_1}
}
@ -950,7 +950,7 @@ VariableDeclStatement{
VariableConst{
x_2
none
__struct_S_1
__type_name_S_1
{
Identifier[not set]{x_2_1}
}
@ -960,7 +960,7 @@ VariableDeclStatement{
VariableConst{
x_3
none
__struct_S_2
__type_name_S_2
{
Identifier[not set]{x_41}
}
@ -970,7 +970,7 @@ VariableDeclStatement{
Variable{
x_4_1
function
__struct_S_2
__type_name_S_2
{
Identifier[not set]{x_3}
}
@ -987,7 +987,7 @@ VariableDeclStatement{
VariableConst{
x_4
none
__struct_S_2
__type_name_S_2
{
Identifier[not set]{x_4_1}
}
@ -998,7 +998,7 @@ VariableDeclStatement{
Variable{
x_4_1
function
__struct_S_2
__type_name_S_2
{
Identifier[not set]{x_3}
}
@ -1015,7 +1015,7 @@ VariableDeclStatement{
VariableConst{
x_4
none
__struct_S_2
__type_name_S_2
{
Identifier[not set]{x_4_1}
}
@ -1066,14 +1066,14 @@ TEST_F(SpvParserTest_CompositeInsert, Struct_Array_Matrix_Vector) {
Variable{
x_37
function
__struct_S_1
__type_name_S_1
}
}
VariableDeclStatement{
VariableConst{
x_1
none
__struct_S_1
__type_name_S_1
{
Identifier[not set]{x_37}
}
@ -1083,7 +1083,7 @@ VariableDeclStatement{
Variable{
x_2_1
function
__struct_S_1
__type_name_S_1
{
Identifier[not set]{x_1}
}
@ -1109,7 +1109,7 @@ VariableDeclStatement{
VariableConst{
x_2
none
__struct_S_1
__type_name_S_1
{
Identifier[not set]{x_2_1}
}

View File

@ -806,7 +806,7 @@ TEST_F(SpvParserTest, RemapStorageBuffer_TypesAndVarDeclarations) {
Variable{
myvar
storage
__access_control_read_write__struct_S
__access_control_read_write__type_name_S
})"));
}

View File

@ -409,7 +409,7 @@ TEST_F(SpvParserTestMiscInstruction, OpUndef_InFunction_Struct) {
VariableConst{
x_11
none
__struct_S
__type_name_S
{
TypeConstructor[not set]{
__type_name_S

View File

@ -467,7 +467,7 @@ TEST_F(SpvParserTest, EmitFunctionVariables_ArrayInitializer_Alias) {
Variable{
x_200
function
__alias_Arr__array__u32_2_stride_16
__type_name_Arr
{
TypeConstructor[not set]{
__type_name_Arr
@ -537,7 +537,7 @@ TEST_F(SpvParserTest, EmitFunctionVariables_ArrayInitializer_Alias_Null) {
Variable{
x_200
function
__alias_Arr__array__u32_2_stride_16
__type_name_Arr
{
TypeConstructor[not set]{
__type_name_Arr
@ -572,7 +572,7 @@ TEST_F(SpvParserTest, EmitFunctionVariables_StructInitializer) {
Variable{
x_200
function
__struct_S
__type_name_S
{
TypeConstructor[not set]{
__type_name_S
@ -612,7 +612,7 @@ TEST_F(SpvParserTest, EmitFunctionVariables_StructInitializer_Null) {
Variable{
x_200
function
__struct_S
__type_name_S
{
TypeConstructor[not set]{
__type_name_S

View File

@ -233,7 +233,8 @@ bool AssumesResultSignednessMatchesFirstOperand(GLSLstd450 extended_opcode) {
// @param tp the type pair
// @returns the unwrapped type pair
typ::Type UnwrapIfNeeded(typ::Type tp) {
return typ::Type{tp.ast->UnwrapIfNeeded(), tp.sem->UnwrapIfNeeded()};
return typ::Type{tp.ast ? tp.ast->UnwrapIfNeeded() : nullptr,
tp.sem ? tp.sem->UnwrapIfNeeded() : nullptr};
}
} // namespace
@ -1037,7 +1038,7 @@ typ::Type ParserImpl::ConvertType(
return result;
}
void ParserImpl::AddConstructedType(Symbol name, typ::Type type) {
void ParserImpl::AddConstructedType(Symbol name, ast::NamedType* type) {
auto iter = constructed_types_.insert(name);
if (iter.second) {
builder_.AST().AddConstructedType(type);
@ -1539,7 +1540,7 @@ TypedExpression ParserImpl::MakeConstantExpression(uint32_t id) {
return {};
}
ast::Expression* ParserImpl::MakeNullValue(typ::Type type) {
ast::Expression* ParserImpl::MakeNullValue(ast::Type* type) {
// TODO(dneto): Use the no-operands constructor syntax when it becomes
// available in Tint.
// https://github.com/gpuweb/gpuweb/issues/685
@ -1550,43 +1551,43 @@ ast::Expression* ParserImpl::MakeNullValue(typ::Type type) {
return nullptr;
}
auto original_type = type;
auto* original_type = type;
type = UnwrapIfNeeded(type);
if (type.ast->Is<ast::Bool>()) {
if (type->Is<ast::Bool>()) {
return create<ast::ScalarConstructorExpression>(
Source{}, create<ast::BoolLiteral>(Source{}, false));
}
if (type.ast->Is<ast::U32>()) {
if (type->Is<ast::U32>()) {
return create<ast::ScalarConstructorExpression>(
Source{}, create<ast::UintLiteral>(Source{}, 0u));
}
if (type.ast->Is<ast::I32>()) {
if (type->Is<ast::I32>()) {
return create<ast::ScalarConstructorExpression>(
Source{}, create<ast::SintLiteral>(Source{}, 0));
}
if (type.ast->Is<ast::F32>()) {
if (type->Is<ast::F32>()) {
return create<ast::ScalarConstructorExpression>(
Source{}, create<ast::FloatLiteral>(Source{}, 0.0f));
}
if (type.ast->Is<ast::TypeName>()) {
if (type->Is<ast::TypeName>()) {
// TODO(amaiorano): No type constructor for TypeName (yet?)
ast::ExpressionList ast_components;
return create<ast::TypeConstructorExpression>(Source{}, original_type,
std::move(ast_components));
}
if (auto vec_ty = typ::As<typ::Vector>(type)) {
if (auto* vec_ty = type->As<ast::Vector>()) {
ast::ExpressionList ast_components;
for (size_t i = 0; i < vec_ty->size(); ++i) {
ast_components.emplace_back(MakeNullValue(typ::Call_type(vec_ty)));
ast_components.emplace_back(MakeNullValue(vec_ty->type()));
}
return create<ast::TypeConstructorExpression>(
Source{}, builder_.ty.MaybeCreateTypename(type),
std::move(ast_components));
}
if (auto mat_ty = typ::As<typ::Matrix>(type)) {
if (auto* mat_ty = type->As<ast::Matrix>()) {
// Matrix components are columns
auto column_ty = builder_.ty.vec(typ::Call_type(mat_ty), mat_ty->rows());
auto column_ty = builder_.ty.vec(mat_ty->type(), mat_ty->rows());
ast::ExpressionList ast_components;
for (size_t i = 0; i < mat_ty->columns(); ++i) {
ast_components.emplace_back(MakeNullValue(column_ty));
@ -1595,18 +1596,18 @@ ast::Expression* ParserImpl::MakeNullValue(typ::Type type) {
Source{}, builder_.ty.MaybeCreateTypename(type),
std::move(ast_components));
}
if (auto arr_ty = typ::As<typ::Array>(type)) {
if (auto* arr_ty = type->As<ast::Array>()) {
ast::ExpressionList ast_components;
for (size_t i = 0; i < arr_ty->size(); ++i) {
ast_components.emplace_back(MakeNullValue(typ::Call_type(arr_ty)));
ast_components.emplace_back(MakeNullValue(arr_ty->type()));
}
return create<ast::TypeConstructorExpression>(
Source{}, builder_.ty.MaybeCreateTypename(original_type),
std::move(ast_components));
}
if (auto struct_ty = typ::As<typ::Struct>(type)) {
if (auto* struct_ty = type->As<ast::Struct>()) {
ast::ExpressionList ast_components;
for (auto* member : struct_ty.ast->members()) {
for (auto* member : struct_ty->members()) {
ast_components.emplace_back(MakeNullValue(member->type()));
}
return create<ast::TypeConstructorExpression>(

View File

@ -334,7 +334,7 @@ class ParserImpl : Reader {
/// Creates an AST expression node for the null value for the given type.
/// @param type the AST type
/// @returns a new expression
ast::Expression* MakeNullValue(typ::Type type);
ast::Expression* MakeNullValue(ast::Type* type);
/// Make a typed expression for the null value for the given type.
/// @param type the AST type
@ -598,7 +598,7 @@ class ParserImpl : Reader {
/// Adds `type` as a constructed type if it hasn't been added yet.
/// @param name the type's unique name
/// @param type the type to add
void AddConstructedType(Symbol name, typ::Type type);
void AddConstructedType(Symbol name, ast::NamedType* type);
/// Creates a new `ast::Node` owned by the ProgramBuilder.
/// @param args the arguments to pass to the type constructor

View File

@ -1408,7 +1408,7 @@ TEST_F(SpvModuleScopeVarParserTest, StructInitializer) {
EXPECT_THAT(module_str, HasSubstr(R"(Variable{
x_200
private
__struct_S
__type_name_S
{
TypeConstructor[not set]{
__type_name_S
@ -1437,7 +1437,7 @@ TEST_F(SpvModuleScopeVarParserTest, StructNullInitializer) {
EXPECT_THAT(module_str, HasSubstr(R"(Variable{
x_200
private
__struct_S
__type_name_S
{
TypeConstructor[not set]{
__type_name_S
@ -1466,7 +1466,7 @@ TEST_F(SpvModuleScopeVarParserTest, StructUndefInitializer) {
EXPECT_THAT(module_str, HasSubstr(R"(Variable{
x_200
private
__struct_S
__type_name_S
{
TypeConstructor[not set]{
__type_name_S
@ -1553,7 +1553,7 @@ TEST_F(SpvModuleScopeVarParserTest, DescriptorGroupDecoration_Valid) {
}
myvar
storage
__access_control_read_write__struct_S
__access_control_read_write__type_name_S
})"))
<< module_str;
}
@ -1607,7 +1607,7 @@ TEST_F(SpvModuleScopeVarParserTest, BindingDecoration_Valid) {
}
myvar
storage
__access_control_read_write__struct_S
__access_control_read_write__type_name_S
})"))
<< module_str;
}
@ -1664,7 +1664,7 @@ TEST_F(SpvModuleScopeVarParserTest,
Variable{
myvar
storage
__access_control_read_write__struct_S
__access_control_read_write__type_name_S
}
)")) << module_str;
}
@ -1693,7 +1693,7 @@ TEST_F(SpvModuleScopeVarParserTest, ColMajorDecoration_Dropped) {
Variable{
myvar
storage
__access_control_read_write__struct_S
__access_control_read_write__type_name_S
}
})")) << module_str;
}
@ -1722,7 +1722,7 @@ TEST_F(SpvModuleScopeVarParserTest, MatrixStrideDecoration_Dropped) {
Variable{
myvar
storage
__access_control_read_write__struct_S
__access_control_read_write__type_name_S
}
})")) << module_str;
}
@ -1772,7 +1772,7 @@ TEST_F(SpvModuleScopeVarParserTest, StorageBuffer_NonWritable_AllMembers) {
Variable{
myvar
storage
__access_control_read_only__struct_S
__access_control_read_only__type_name_S
}
})")) << module_str;
}
@ -1801,7 +1801,7 @@ TEST_F(SpvModuleScopeVarParserTest, StorageBuffer_NonWritable_NotAllMembers) {
Variable{
myvar
storage
__access_control_read_write__struct_S
__access_control_read_write__type_name_S
}
})")) << module_str;
}
@ -1833,7 +1833,7 @@ TEST_F(
Variable{
myvar
storage
__access_control_read_write__struct_S
__access_control_read_write__type_name_S
}
})")) << module_str;
}

View File

@ -401,7 +401,7 @@ Expect<bool> ParserImpl::expect_global_decl() {
if (!expect("type alias", Token::Type::kSemicolon))
return Failure::kErrored;
builder_.AST().AddConstructedType(ta.value);
builder_.AST().AddConstructedType(const_cast<ast::Alias*>(ta.value.ast));
return true;
}
@ -415,7 +415,8 @@ Expect<bool> ParserImpl::expect_global_decl() {
register_constructed(
builder_.Symbols().NameFor(str.value->impl()->name()), str.value);
builder_.AST().AddConstructedType(str.value);
builder_.AST().AddConstructedType(
const_cast<ast::Struct*>(str.value.ast));
return true;
}
@ -561,7 +562,8 @@ Maybe<ParserImpl::VarDeclInfo> ParserImpl::variable_decl() {
if (decl.errored)
return Failure::kErrored;
if (decl->type->UnwrapAll()->is_handle()) {
if ((decl->type.sem && decl->type.sem->UnwrapAll()->is_handle()) ||
(decl->type.ast && decl->type.ast->UnwrapAll()->is_handle())) {
// handle types implicitly have the `UniformConstant` storage class.
if (explicit_sc.matched) {
return add_error(
@ -960,7 +962,7 @@ Maybe<ast::StorageClass> ParserImpl::variable_storage_decoration() {
// type_alias
// : TYPE IDENT EQUAL type_decl
Maybe<typ::Type> ParserImpl::type_alias() {
Maybe<typ::Alias> ParserImpl::type_alias() {
auto t = peek();
if (!t.IsType())
return Failure::kNoMatch;
@ -1234,7 +1236,7 @@ Expect<ast::StorageClass> ParserImpl::expect_storage_class(
// struct_decl
// : struct_decoration_decl* STRUCT IDENT struct_body_decl
Maybe<sem::StructType*> ParserImpl::struct_decl(ast::DecorationList& decos) {
Maybe<typ::Struct> ParserImpl::struct_decl(ast::DecorationList& decos) {
auto t = peek();
auto source = t.source();
@ -1250,8 +1252,9 @@ Maybe<sem::StructType*> ParserImpl::struct_decl(ast::DecorationList& decos) {
return Failure::kErrored;
auto sym = builder_.Symbols().Register(name.value);
return create<sem::StructType>(create<ast::Struct>(
source, sym, std::move(body.value), std::move(decos)));
auto* str =
create<ast::Struct>(source, sym, std::move(body.value), std::move(decos));
return typ::Struct{str, create<sem::StructType>(str)};
}
// struct_body_decl

View File

@ -398,7 +398,7 @@ class ParserImpl {
Maybe<ast::StorageClass> variable_storage_decoration();
/// Parses a `type_alias` grammar element
/// @returns the type alias or nullptr on error
Maybe<typ::Type> type_alias();
Maybe<typ::Alias> type_alias();
/// Parses a `type_decl` grammar element
/// @returns the parsed Type or nullptr if none matched.
Maybe<typ::Type> type_decl();
@ -415,7 +415,7 @@ class ParserImpl {
/// `struct_decoration_decl*` provided as `decos`.
/// @returns the struct type or nullptr on error
/// @param decos the list of decorations for the struct declaration.
Maybe<sem::StructType*> struct_decl(ast::DecorationList& decos);
Maybe<typ::Struct> struct_decl(ast::DecorationList& decos);
/// Parses a `struct_body_decl` grammar element, erroring on parse failure.
/// @returns the struct members
Expect<ast::StructMemberList> expect_struct_body_decl();

View File

@ -28,8 +28,8 @@ TEST_F(ParserImplTest, ConstExpr_TypeDecl) {
ASSERT_TRUE(e->Is<ast::TypeConstructorExpression>());
auto* t = e->As<ast::TypeConstructorExpression>();
ASSERT_TRUE(t->type()->Is<sem::Vector>());
EXPECT_EQ(t->type()->As<sem::Vector>()->size(), 2u);
ASSERT_TRUE(t->type()->Is<ast::Vector>());
EXPECT_EQ(t->type()->As<ast::Vector>()->size(), 2u);
ASSERT_EQ(t->values().size(), 2u);
auto& v = t->values();
@ -56,8 +56,8 @@ TEST_F(ParserImplTest, ConstExpr_TypeDecl_Empty) {
ASSERT_TRUE(e->Is<ast::TypeConstructorExpression>());
auto* t = e->As<ast::TypeConstructorExpression>();
ASSERT_TRUE(t->type()->Is<sem::Vector>());
EXPECT_EQ(t->type()->As<sem::Vector>()->size(), 2u);
ASSERT_TRUE(t->type()->Is<ast::Vector>());
EXPECT_EQ(t->type()->As<ast::Vector>()->size(), 2u);
ASSERT_EQ(t->values().size(), 0u);
}
@ -71,8 +71,8 @@ TEST_F(ParserImplTest, ConstExpr_TypeDecl_TrailingComma) {
ASSERT_TRUE(e->Is<ast::TypeConstructorExpression>());
auto* t = e->As<ast::TypeConstructorExpression>();
ASSERT_TRUE(t->type()->Is<sem::Vector>());
EXPECT_EQ(t->type()->As<sem::Vector>()->size(), 2u);
ASSERT_TRUE(t->type()->Is<ast::Vector>());
EXPECT_EQ(t->type()->As<ast::Vector>()->size(), 2u);
ASSERT_EQ(t->values().size(), 2u);
ASSERT_TRUE(t->values()[0]->Is<ast::ScalarConstructorExpression>());

View File

@ -34,14 +34,14 @@ TEST_F(ParserImplTest, FunctionDecl) {
EXPECT_EQ(f->symbol(), p->builder().Symbols().Get("main"));
ASSERT_NE(f->return_type(), nullptr);
EXPECT_TRUE(f->return_type()->Is<sem::Void>());
EXPECT_TRUE(f->return_type()->Is<ast::Void>());
ASSERT_EQ(f->params().size(), 2u);
EXPECT_EQ(f->params()[0]->symbol(), p->builder().Symbols().Get("a"));
EXPECT_EQ(f->params()[1]->symbol(), p->builder().Symbols().Get("b"));
ASSERT_NE(f->return_type(), nullptr);
EXPECT_TRUE(f->return_type()->Is<sem::Void>());
EXPECT_TRUE(f->return_type()->Is<ast::Void>());
auto* body = f->body();
ASSERT_EQ(body->size(), 1u);
@ -62,10 +62,8 @@ TEST_F(ParserImplTest, FunctionDecl_DecorationList) {
EXPECT_EQ(f->symbol(), p->builder().Symbols().Get("main"));
ASSERT_NE(f->return_type(), nullptr);
EXPECT_TRUE(f->return_type()->Is<sem::Void>());
EXPECT_TRUE(f->return_type()->Is<ast::Void>());
ASSERT_EQ(f->params().size(), 0u);
ASSERT_NE(f->return_type(), nullptr);
EXPECT_TRUE(f->return_type()->Is<sem::Void>());
auto& decorations = f->decorations();
ASSERT_EQ(decorations.size(), 1u);
@ -100,10 +98,8 @@ fn main() { return; })");
EXPECT_EQ(f->symbol(), p->builder().Symbols().Get("main"));
ASSERT_NE(f->return_type(), nullptr);
EXPECT_TRUE(f->return_type()->Is<sem::Void>());
EXPECT_TRUE(f->return_type()->Is<ast::Void>());
ASSERT_EQ(f->params().size(), 0u);
ASSERT_NE(f->return_type(), nullptr);
EXPECT_TRUE(f->return_type()->Is<sem::Void>());
auto& decorations = f->decorations();
ASSERT_EQ(decorations.size(), 2u);
@ -145,10 +141,8 @@ fn main() { return; })");
EXPECT_EQ(f->symbol(), p->builder().Symbols().Get("main"));
ASSERT_NE(f->return_type(), nullptr);
EXPECT_TRUE(f->return_type()->Is<sem::Void>());
EXPECT_TRUE(f->return_type()->Is<ast::Void>());
ASSERT_EQ(f->params().size(), 0u);
ASSERT_NE(f->return_type(), nullptr);
EXPECT_TRUE(f->return_type()->Is<sem::Void>());
auto& decos = f->decorations();
ASSERT_EQ(decos.size(), 2u);
@ -187,7 +181,7 @@ TEST_F(ParserImplTest, FunctionDecl_ReturnTypeDecorationList) {
EXPECT_EQ(f->symbol(), p->builder().Symbols().Get("main"));
ASSERT_NE(f->return_type(), nullptr);
EXPECT_TRUE(f->return_type()->Is<sem::F32>());
EXPECT_TRUE(f->return_type()->Is<ast::F32>());
ASSERT_EQ(f->params().size(), 0u);
auto& decorations = f->decorations();

View File

@ -33,8 +33,8 @@ TEST_F(ParserImplTest, GlobalConstantDecl) {
EXPECT_TRUE(e->is_const());
EXPECT_EQ(e->symbol(), p->builder().Symbols().Get("a"));
ASSERT_NE(e->declared_type(), nullptr);
EXPECT_TRUE(e->declared_type()->Is<sem::F32>());
ASSERT_NE(e->type(), nullptr);
EXPECT_TRUE(e->type()->Is<ast::F32>());
EXPECT_EQ(e->source().range.begin.line, 1u);
EXPECT_EQ(e->source().range.begin.column, 5u);
@ -114,8 +114,8 @@ TEST_F(ParserImplTest, GlobalConstantDec_Override_WithId) {
EXPECT_TRUE(e->is_const());
EXPECT_EQ(e->symbol(), p->builder().Symbols().Get("a"));
ASSERT_NE(e->declared_type(), nullptr);
EXPECT_TRUE(e->declared_type()->Is<sem::F32>());
ASSERT_NE(e->type(), nullptr);
EXPECT_TRUE(e->type()->Is<ast::F32>());
EXPECT_EQ(e->source().range.begin.line, 1u);
EXPECT_EQ(e->source().range.begin.column, 21u);

View File

@ -84,10 +84,10 @@ TEST_F(ParserImplTest, GlobalDecl_TypeAlias) {
auto program = p->program();
ASSERT_EQ(program.AST().ConstructedTypes().size(), 1u);
ASSERT_TRUE(program.AST().ConstructedTypes()[0]->Is<sem::Alias>());
ASSERT_TRUE(program.AST().ConstructedTypes()[0]->Is<ast::Alias>());
EXPECT_EQ(
program.Symbols().NameFor(
program.AST().ConstructedTypes()[0]->As<sem::Alias>()->symbol()),
program.AST().ConstructedTypes()[0]->As<ast::Alias>()->symbol()),
"A");
}
@ -102,14 +102,16 @@ type B = A;)");
auto program = p->program();
ASSERT_EQ(program.AST().ConstructedTypes().size(), 2u);
ASSERT_TRUE(program.AST().ConstructedTypes()[0]->Is<sem::StructType>());
auto* str = program.AST().ConstructedTypes()[0]->As<sem::StructType>();
EXPECT_EQ(str->impl()->name(), program.Symbols().Get("A"));
ASSERT_TRUE(program.AST().ConstructedTypes()[0]->Is<ast::Struct>());
auto* str = program.AST().ConstructedTypes()[0]->As<ast::Struct>();
EXPECT_EQ(str->name(), program.Symbols().Get("A"));
ASSERT_TRUE(program.AST().ConstructedTypes()[1]->Is<sem::Alias>());
auto* alias = program.AST().ConstructedTypes()[1]->As<sem::Alias>();
ASSERT_TRUE(program.AST().ConstructedTypes()[1]->Is<ast::Alias>());
auto* alias = program.AST().ConstructedTypes()[1]->As<ast::Alias>();
EXPECT_EQ(alias->symbol(), program.Symbols().Get("B"));
EXPECT_EQ(alias->type(), str);
auto* tn = alias->type()->As<ast::TypeName>();
EXPECT_NE(tn, nullptr);
EXPECT_EQ(tn->name(), str->name());
}
TEST_F(ParserImplTest, GlobalDecl_TypeAlias_Invalid) {
@ -163,13 +165,13 @@ TEST_F(ParserImplTest, GlobalDecl_ParsesStruct) {
auto program = p->program();
ASSERT_EQ(program.AST().ConstructedTypes().size(), 1u);
auto t = program.AST().ConstructedTypes()[0];
auto* t = program.AST().ConstructedTypes()[0];
ASSERT_NE(t, nullptr);
ASSERT_TRUE(t->Is<sem::StructType>());
ASSERT_TRUE(t->Is<ast::Struct>());
auto* str = t->As<sem::StructType>();
EXPECT_EQ(str->impl()->name(), program.Symbols().Get("A"));
EXPECT_EQ(str->impl()->members().size(), 2u);
auto* str = t->As<ast::Struct>();
EXPECT_EQ(str->name(), program.Symbols().Get("A"));
EXPECT_EQ(str->members().size(), 2u);
}
TEST_F(ParserImplTest, GlobalDecl_Struct_WithStride) {
@ -181,18 +183,18 @@ TEST_F(ParserImplTest, GlobalDecl_Struct_WithStride) {
auto program = p->program();
ASSERT_EQ(program.AST().ConstructedTypes().size(), 1u);
auto t = program.AST().ConstructedTypes()[0];
auto* t = program.AST().ConstructedTypes()[0];
ASSERT_NE(t, nullptr);
ASSERT_TRUE(t->Is<sem::StructType>());
ASSERT_TRUE(t->Is<ast::Struct>());
auto* str = t->As<sem::StructType>();
EXPECT_EQ(str->impl()->name(), program.Symbols().Get("A"));
EXPECT_EQ(str->impl()->members().size(), 1u);
auto* str = t->As<ast::Struct>();
EXPECT_EQ(str->name(), program.Symbols().Get("A"));
EXPECT_EQ(str->members().size(), 1u);
EXPECT_FALSE(str->IsBlockDecorated());
const auto ty = str->impl()->members()[0]->type();
ASSERT_TRUE(ty->Is<sem::ArrayType>());
const auto* arr = ty->As<sem::ArrayType>();
const auto* ty = str->members()[0]->type();
ASSERT_TRUE(ty->Is<ast::Array>());
const auto* arr = ty->As<ast::Array>();
ASSERT_EQ(arr->decorations().size(), 1u);
auto* stride = arr->decorations()[0];
@ -208,13 +210,13 @@ TEST_F(ParserImplTest, GlobalDecl_Struct_WithDecoration) {
auto program = p->program();
ASSERT_EQ(program.AST().ConstructedTypes().size(), 1u);
auto t = program.AST().ConstructedTypes()[0];
auto* t = program.AST().ConstructedTypes()[0];
ASSERT_NE(t, nullptr);
ASSERT_TRUE(t->Is<sem::StructType>());
ASSERT_TRUE(t->Is<ast::Struct>());
auto* str = t->As<sem::StructType>();
EXPECT_EQ(str->impl()->name(), program.Symbols().Get("A"));
EXPECT_EQ(str->impl()->members().size(), 1u);
auto* str = t->As<ast::Struct>();
EXPECT_EQ(str->name(), program.Symbols().Get("A"));
EXPECT_EQ(str->members().size(), 1u);
EXPECT_TRUE(str->IsBlockDecorated());
}

View File

@ -31,7 +31,7 @@ TEST_F(ParserImplTest, GlobalVariableDecl_WithoutConstructor) {
ASSERT_NE(e.value, nullptr);
EXPECT_EQ(e->symbol(), p->builder().Symbols().Get("a"));
EXPECT_TRUE(e->declared_type()->Is<sem::F32>());
EXPECT_TRUE(e->type()->Is<ast::F32>());
EXPECT_EQ(e->declared_storage_class(), ast::StorageClass::kPrivate);
EXPECT_EQ(e->source().range.begin.line, 1u);
@ -54,7 +54,7 @@ TEST_F(ParserImplTest, GlobalVariableDecl_WithConstructor) {
ASSERT_NE(e.value, nullptr);
EXPECT_EQ(e->symbol(), p->builder().Symbols().Get("a"));
EXPECT_TRUE(e->declared_type()->Is<sem::F32>());
EXPECT_TRUE(e->type()->Is<ast::F32>());
EXPECT_EQ(e->declared_storage_class(), ast::StorageClass::kPrivate);
EXPECT_EQ(e->source().range.begin.line, 1u);
@ -79,8 +79,8 @@ TEST_F(ParserImplTest, GlobalVariableDecl_WithDecoration) {
ASSERT_NE(e.value, nullptr);
EXPECT_EQ(e->symbol(), p->builder().Symbols().Get("a"));
ASSERT_NE(e->declared_type(), nullptr);
EXPECT_TRUE(e->declared_type()->Is<sem::F32>());
ASSERT_NE(e->type(), nullptr);
EXPECT_TRUE(e->type()->Is<ast::F32>());
EXPECT_EQ(e->declared_storage_class(), ast::StorageClass::kUniform);
EXPECT_EQ(e->source().range.begin.line, 1u);
@ -109,8 +109,8 @@ TEST_F(ParserImplTest, GlobalVariableDecl_WithDecoration_MulitpleGroups) {
ASSERT_NE(e.value, nullptr);
EXPECT_EQ(e->symbol(), p->builder().Symbols().Get("a"));
ASSERT_NE(e->declared_type(), nullptr);
EXPECT_TRUE(e->declared_type()->Is<sem::F32>());
ASSERT_NE(e->type(), nullptr);
EXPECT_TRUE(e->type()->Is<ast::F32>());
EXPECT_EQ(e->declared_storage_class(), ast::StorageClass::kUniform);
EXPECT_EQ(e->source().range.begin.line, 1u);
@ -180,7 +180,7 @@ TEST_F(ParserImplTest, GlobalVariableDecl_SamplerImplicitStorageClass) {
ASSERT_NE(e.value, nullptr);
EXPECT_EQ(e->symbol(), p->builder().Symbols().Get("s"));
EXPECT_TRUE(e->declared_type()->Is<sem::Sampler>());
EXPECT_TRUE(e->type()->Is<ast::Sampler>());
EXPECT_EQ(e->declared_storage_class(), ast::StorageClass::kUniformConstant);
}
@ -196,7 +196,7 @@ TEST_F(ParserImplTest, GlobalVariableDecl_TextureImplicitStorageClass) {
ASSERT_NE(e.value, nullptr);
EXPECT_EQ(e->symbol(), p->builder().Symbols().Get("s"));
EXPECT_TRUE(e->declared_type()->UnwrapAll()->Is<sem::Texture>());
EXPECT_TRUE(e->type()->UnwrapAll()->Is<ast::Texture>());
EXPECT_EQ(e->declared_storage_class(), ast::StorageClass::kUniformConstant);
}
@ -210,7 +210,7 @@ TEST_F(ParserImplTest, GlobalVariableDecl_StorageClassIn_Deprecated) {
ASSERT_FALSE(p->has_error()) << p->error();
EXPECT_EQ(e->symbol(), p->builder().Symbols().Get("a"));
EXPECT_TRUE(e->declared_type()->Is<sem::F32>());
EXPECT_TRUE(e->type()->Is<ast::F32>());
EXPECT_EQ(e->declared_storage_class(), ast::StorageClass::kInput);
EXPECT_EQ(
@ -231,7 +231,7 @@ TEST_F(ParserImplTest, GlobalVariableDecl_StorageClassOut_Deprecated) {
ASSERT_FALSE(p->has_error()) << p->error();
EXPECT_EQ(e->symbol(), p->builder().Symbols().Get("a"));
EXPECT_TRUE(e->declared_type()->Is<sem::F32>());
EXPECT_TRUE(e->type()->Is<ast::F32>());
EXPECT_EQ(e->declared_storage_class(), ast::StorageClass::kOutput);
EXPECT_EQ(

View File

@ -22,15 +22,13 @@ namespace {
TEST_F(ParserImplTest, ParamList_Single) {
auto p = parser("a : i32");
auto* i32 = p->builder().create<sem::I32>();
auto e = p->expect_param_list();
ASSERT_FALSE(p->has_error()) << p->error();
ASSERT_FALSE(e.errored);
EXPECT_EQ(e.value.size(), 1u);
EXPECT_EQ(e.value[0]->symbol(), p->builder().Symbols().Get("a"));
EXPECT_EQ(e.value[0]->declared_type(), i32);
EXPECT_TRUE(e.value[0]->type()->Is<ast::I32>());
EXPECT_TRUE(e.value[0]->is_const());
ASSERT_EQ(e.value[0]->source().range.begin.line, 1u);
@ -42,17 +40,13 @@ TEST_F(ParserImplTest, ParamList_Single) {
TEST_F(ParserImplTest, ParamList_Multiple) {
auto p = parser("a : i32, b: f32, c: vec2<f32>");
auto* i32 = p->builder().create<sem::I32>();
auto* f32 = p->builder().create<sem::F32>();
auto* vec2 = p->builder().create<sem::Vector>(f32, 2);
auto e = p->expect_param_list();
ASSERT_FALSE(p->has_error()) << p->error();
ASSERT_FALSE(e.errored);
EXPECT_EQ(e.value.size(), 3u);
EXPECT_EQ(e.value[0]->symbol(), p->builder().Symbols().Get("a"));
EXPECT_EQ(e.value[0]->declared_type(), i32);
EXPECT_TRUE(e.value[0]->type()->Is<ast::I32>());
EXPECT_TRUE(e.value[0]->is_const());
ASSERT_EQ(e.value[0]->source().range.begin.line, 1u);
@ -61,7 +55,7 @@ TEST_F(ParserImplTest, ParamList_Multiple) {
ASSERT_EQ(e.value[0]->source().range.end.column, 2u);
EXPECT_EQ(e.value[1]->symbol(), p->builder().Symbols().Get("b"));
EXPECT_EQ(e.value[1]->declared_type(), f32);
EXPECT_TRUE(e.value[1]->type()->Is<ast::F32>());
EXPECT_TRUE(e.value[1]->is_const());
ASSERT_EQ(e.value[1]->source().range.begin.line, 1u);
@ -70,7 +64,9 @@ TEST_F(ParserImplTest, ParamList_Multiple) {
ASSERT_EQ(e.value[1]->source().range.end.column, 11u);
EXPECT_EQ(e.value[2]->symbol(), p->builder().Symbols().Get("c"));
EXPECT_EQ(e.value[2]->declared_type(), vec2);
ASSERT_TRUE(e.value[2]->type()->Is<ast::Vector>());
ASSERT_TRUE(e.value[2]->type()->As<ast::Vector>()->type()->Is<ast::F32>());
EXPECT_EQ(e.value[2]->type()->As<ast::Vector>()->size(), 2u);
EXPECT_TRUE(e.value[2]->is_const());
ASSERT_EQ(e.value[2]->source().range.begin.line, 1u);
@ -100,16 +96,15 @@ TEST_F(ParserImplTest, ParamList_Decorations) {
"[[builtin(position)]] coord : vec4<f32>, "
"[[location(1)]] loc1 : f32");
auto* f32 = p->builder().create<sem::F32>();
auto* vec4 = p->builder().create<sem::Vector>(f32, 4);
auto e = p->expect_param_list();
ASSERT_FALSE(p->has_error()) << p->error();
ASSERT_FALSE(e.errored);
ASSERT_EQ(e.value.size(), 2u);
EXPECT_EQ(e.value[0]->symbol(), p->builder().Symbols().Get("coord"));
EXPECT_EQ(e.value[0]->declared_type(), vec4);
ASSERT_TRUE(e.value[0]->type()->Is<ast::Vector>());
EXPECT_TRUE(e.value[0]->type()->As<ast::Vector>()->type()->Is<ast::F32>());
EXPECT_EQ(e.value[0]->type()->As<ast::Vector>()->size(), 4u);
EXPECT_TRUE(e.value[0]->is_const());
auto decos0 = e.value[0]->decorations();
ASSERT_EQ(decos0.size(), 1u);
@ -123,7 +118,7 @@ TEST_F(ParserImplTest, ParamList_Decorations) {
ASSERT_EQ(e.value[0]->source().range.end.column, 28u);
EXPECT_EQ(e.value[1]->symbol(), p->builder().Symbols().Get("loc1"));
EXPECT_EQ(e.value[1]->declared_type(), f32);
EXPECT_TRUE(e.value[1]->type()->Is<ast::F32>());
EXPECT_TRUE(e.value[1]->is_const());
auto decos1 = e.value[1]->decorations();
ASSERT_EQ(decos1.size(), 1u);

View File

@ -141,7 +141,9 @@ TEST_F(ParserImplTest, PrimaryExpression_TypeDecl_StructConstructor_Empty) {
ASSERT_TRUE(e->Is<ast::TypeConstructorExpression>());
auto* constructor = e->As<ast::TypeConstructorExpression>();
EXPECT_EQ(constructor->type(), p->get_constructed("S"));
ASSERT_TRUE(constructor->type()->Is<ast::TypeName>());
EXPECT_EQ(constructor->type()->As<ast::TypeName>()->name(),
p->builder().Symbols().Get("S"));
auto values = constructor->values();
ASSERT_EQ(values.size(), 0u);
@ -164,7 +166,9 @@ TEST_F(ParserImplTest, PrimaryExpression_TypeDecl_StructConstructor_NotEmpty) {
ASSERT_TRUE(e->Is<ast::TypeConstructorExpression>());
auto* constructor = e->As<ast::TypeConstructorExpression>();
EXPECT_EQ(constructor->type(), p->get_constructed("S"));
ASSERT_TRUE(constructor->type()->Is<ast::TypeName>());
EXPECT_EQ(constructor->type()->As<ast::TypeName>()->name(),
p->builder().Symbols().Get("S"));
auto values = constructor->values();
ASSERT_EQ(values.size(), 2u);
@ -237,8 +241,6 @@ TEST_F(ParserImplTest, PrimaryExpression_ParenExpr_InvalidExpr) {
TEST_F(ParserImplTest, PrimaryExpression_Cast) {
auto p = parser("f32(1)");
auto* f32 = p->builder().create<sem::F32>();
auto e = p->primary_expression();
EXPECT_TRUE(e.matched);
EXPECT_FALSE(e.errored);
@ -248,7 +250,7 @@ TEST_F(ParserImplTest, PrimaryExpression_Cast) {
ASSERT_TRUE(e->Is<ast::TypeConstructorExpression>());
auto* c = e->As<ast::TypeConstructorExpression>();
ASSERT_EQ(c->type(), f32);
ASSERT_TRUE(c->type()->Is<ast::F32>());
ASSERT_EQ(c->values().size(), 1u);
ASSERT_TRUE(c->values()[0]->Is<ast::ConstructorExpression>());
@ -258,8 +260,6 @@ TEST_F(ParserImplTest, PrimaryExpression_Cast) {
TEST_F(ParserImplTest, PrimaryExpression_Bitcast) {
auto p = parser("bitcast<f32>(1)");
auto* f32 = p->builder().create<sem::F32>();
auto e = p->primary_expression();
EXPECT_TRUE(e.matched);
EXPECT_FALSE(e.errored);
@ -268,8 +268,7 @@ TEST_F(ParserImplTest, PrimaryExpression_Bitcast) {
ASSERT_TRUE(e->Is<ast::BitcastExpression>());
auto* c = e->As<ast::BitcastExpression>();
ASSERT_EQ(c->type(), f32);
ASSERT_TRUE(c->type()->Is<ast::F32>());
ASSERT_TRUE(c->expr()->Is<ast::ConstructorExpression>());
ASSERT_TRUE(c->expr()->Is<ast::ScalarConstructorExpression>());
}

View File

@ -23,7 +23,6 @@ TEST_F(ParserImplTest, StructBodyDecl_Parses) {
auto p = parser("{a : i32;}");
auto& builder = p->builder();
auto* i32 = builder.create<sem::I32>();
auto m = p->expect_struct_body_decl();
ASSERT_FALSE(p->has_error());
@ -32,7 +31,7 @@ TEST_F(ParserImplTest, StructBodyDecl_Parses) {
const auto* mem = m.value[0];
EXPECT_EQ(mem->symbol(), builder.Symbols().Get("a"));
EXPECT_EQ(mem->type(), i32);
EXPECT_TRUE(mem->type()->Is<ast::I32>());
EXPECT_EQ(mem->decorations().size(), 0u);
}

View File

@ -23,7 +23,6 @@ TEST_F(ParserImplTest, StructMember_Parses) {
auto p = parser("a : i32;");
auto& builder = p->builder();
auto i32 = builder.ty.i32();
auto decos = p->decoration_list();
EXPECT_FALSE(decos.errored);
@ -36,18 +35,17 @@ TEST_F(ParserImplTest, StructMember_Parses) {
ASSERT_NE(m.value, nullptr);
EXPECT_EQ(m->symbol(), builder.Symbols().Get("a"));
EXPECT_EQ(m->type(), i32);
EXPECT_TRUE(m->type()->Is<ast::I32>());
EXPECT_EQ(m->decorations().size(), 0u);
EXPECT_EQ(m->source().range, (Source::Range{{1u, 1u}, {1u, 2u}}));
EXPECT_EQ(m->type().ast->source().range, (Source::Range{{1u, 5u}, {1u, 8u}}));
EXPECT_EQ(m->type()->source().range, (Source::Range{{1u, 5u}, {1u, 8u}}));
}
TEST_F(ParserImplTest, StructMember_ParsesWithOffsetDecoration_DEPRECATED) {
auto p = parser("[[offset(2)]] a : i32;");
auto& builder = p->builder();
auto i32 = builder.ty.i32();
auto decos = p->decoration_list();
EXPECT_FALSE(decos.errored);
@ -60,7 +58,7 @@ TEST_F(ParserImplTest, StructMember_ParsesWithOffsetDecoration_DEPRECATED) {
ASSERT_NE(m.value, nullptr);
EXPECT_EQ(m->symbol(), builder.Symbols().Get("a"));
EXPECT_EQ(m->type(), i32);
EXPECT_TRUE(m->type()->Is<ast::I32>());
EXPECT_EQ(m->decorations().size(), 1u);
EXPECT_TRUE(m->decorations()[0]->Is<ast::StructMemberOffsetDecoration>());
EXPECT_EQ(
@ -68,15 +66,13 @@ TEST_F(ParserImplTest, StructMember_ParsesWithOffsetDecoration_DEPRECATED) {
2u);
EXPECT_EQ(m->source().range, (Source::Range{{1u, 15u}, {1u, 16u}}));
EXPECT_EQ(m->type().ast->source().range,
(Source::Range{{1u, 19u}, {1u, 22u}}));
EXPECT_EQ(m->type()->source().range, (Source::Range{{1u, 19u}, {1u, 22u}}));
}
TEST_F(ParserImplTest, StructMember_ParsesWithAlignDecoration) {
auto p = parser("[[align(2)]] a : i32;");
auto& builder = p->builder();
auto i32 = builder.ty.i32();
auto decos = p->decoration_list();
EXPECT_FALSE(decos.errored);
@ -89,22 +85,20 @@ TEST_F(ParserImplTest, StructMember_ParsesWithAlignDecoration) {
ASSERT_NE(m.value, nullptr);
EXPECT_EQ(m->symbol(), builder.Symbols().Get("a"));
EXPECT_EQ(m->type(), i32);
EXPECT_TRUE(m->type()->Is<ast::I32>());
EXPECT_EQ(m->decorations().size(), 1u);
EXPECT_TRUE(m->decorations()[0]->Is<ast::StructMemberAlignDecoration>());
EXPECT_EQ(
m->decorations()[0]->As<ast::StructMemberAlignDecoration>()->align(), 2u);
EXPECT_EQ(m->source().range, (Source::Range{{1u, 14u}, {1u, 15u}}));
EXPECT_EQ(m->type().ast->source().range,
(Source::Range{{1u, 18u}, {1u, 21u}}));
EXPECT_EQ(m->type()->source().range, (Source::Range{{1u, 18u}, {1u, 21u}}));
}
TEST_F(ParserImplTest, StructMember_ParsesWithSizeDecoration) {
auto p = parser("[[size(2)]] a : i32;");
auto& builder = p->builder();
auto i32 = builder.ty.i32();
auto decos = p->decoration_list();
EXPECT_FALSE(decos.errored);
@ -117,22 +111,20 @@ TEST_F(ParserImplTest, StructMember_ParsesWithSizeDecoration) {
ASSERT_NE(m.value, nullptr);
EXPECT_EQ(m->symbol(), builder.Symbols().Get("a"));
EXPECT_EQ(m->type(), i32);
EXPECT_TRUE(m->type()->Is<ast::I32>());
EXPECT_EQ(m->decorations().size(), 1u);
EXPECT_TRUE(m->decorations()[0]->Is<ast::StructMemberSizeDecoration>());
EXPECT_EQ(m->decorations()[0]->As<ast::StructMemberSizeDecoration>()->size(),
2u);
EXPECT_EQ(m->source().range, (Source::Range{{1u, 13u}, {1u, 14u}}));
EXPECT_EQ(m->type().ast->source().range,
(Source::Range{{1u, 17u}, {1u, 20u}}));
EXPECT_EQ(m->type()->source().range, (Source::Range{{1u, 17u}, {1u, 20u}}));
}
TEST_F(ParserImplTest, StructMember_ParsesWithDecoration) {
auto p = parser("[[size(2)]] a : i32;");
auto& builder = p->builder();
auto i32 = builder.ty.i32();
auto decos = p->decoration_list();
EXPECT_FALSE(decos.errored);
@ -145,15 +137,14 @@ TEST_F(ParserImplTest, StructMember_ParsesWithDecoration) {
ASSERT_NE(m.value, nullptr);
EXPECT_EQ(m->symbol(), builder.Symbols().Get("a"));
EXPECT_EQ(m->type(), i32);
EXPECT_TRUE(m->type()->Is<ast::I32>());
EXPECT_EQ(m->decorations().size(), 1u);
EXPECT_TRUE(m->decorations()[0]->Is<ast::StructMemberSizeDecoration>());
EXPECT_EQ(m->decorations()[0]->As<ast::StructMemberSizeDecoration>()->size(),
2u);
EXPECT_EQ(m->source().range, (Source::Range{{1u, 13u}, {1u, 14u}}));
EXPECT_EQ(m->type().ast->source().range,
(Source::Range{{1u, 17u}, {1u, 20u}}));
EXPECT_EQ(m->type()->source().range, (Source::Range{{1u, 17u}, {1u, 20u}}));
}
TEST_F(ParserImplTest, StructMember_ParsesWithMultipleDecorations) {
@ -161,7 +152,6 @@ TEST_F(ParserImplTest, StructMember_ParsesWithMultipleDecorations) {
[[align(4)]] a : i32;)");
auto& builder = p->builder();
auto i32 = builder.ty.i32();
auto decos = p->decoration_list();
EXPECT_FALSE(decos.errored);
@ -174,7 +164,7 @@ TEST_F(ParserImplTest, StructMember_ParsesWithMultipleDecorations) {
ASSERT_NE(m.value, nullptr);
EXPECT_EQ(m->symbol(), builder.Symbols().Get("a"));
EXPECT_EQ(m->type(), i32);
EXPECT_TRUE(m->type()->Is<ast::I32>());
EXPECT_EQ(m->decorations().size(), 2u);
EXPECT_TRUE(m->decorations()[0]->Is<ast::StructMemberSizeDecoration>());
EXPECT_EQ(m->decorations()[0]->As<ast::StructMemberSizeDecoration>()->size(),
@ -184,8 +174,7 @@ TEST_F(ParserImplTest, StructMember_ParsesWithMultipleDecorations) {
m->decorations()[1]->As<ast::StructMemberAlignDecoration>()->align(), 4u);
EXPECT_EQ(m->source().range, (Source::Range{{2u, 14u}, {2u, 15u}}));
EXPECT_EQ(m->type().ast->source().range,
(Source::Range{{2u, 18u}, {2u, 21u}}));
EXPECT_EQ(m->type()->source().range, (Source::Range{{2u, 18u}, {2u, 21u}}));
}
TEST_F(ParserImplTest, StructMember_InvalidDecoration) {

View File

@ -308,7 +308,7 @@ namespace ArrayStrideTests {
namespace {
struct Params {
create_type_func_ptr create_el_type;
create_ast_type_func_ptr create_el_type;
uint32_t stride;
bool should_pass;
};
@ -318,17 +318,16 @@ struct TestWithParams : ResolverTestWithParam<Params> {};
using ArrayStrideTest = TestWithParams;
TEST_P(ArrayStrideTest, All) {
auto& params = GetParam();
auto el_ty = params.create_el_type(ty);
auto* el_ty = params.create_el_type(ty);
std::stringstream ss;
ss << "el_ty: " << el_ty->FriendlyName(Symbols())
<< ", stride: " << params.stride
ss << "el_ty: " << FriendlyName(el_ty) << ", stride: " << params.stride
<< ", should_pass: " << params.should_pass;
SCOPED_TRACE(ss.str());
auto arr = ty.array(el_ty, 4, params.stride);
auto arr = ty.array(Source{{12, 34}}, el_ty, 4, params.stride);
Global(Source{{12, 34}}, "myarray", arr, ast::StorageClass::kInput);
Global("myarray", arr, ast::StorageClass::kInput);
if (params.should_pass) {
EXPECT_TRUE(r()->Resolve()) << r()->error();
@ -366,58 +365,58 @@ INSTANTIATE_TEST_SUITE_P(
testing::Values(
// Succeed because stride >= element size (while being multiple of
// element alignment)
Params{ty_u32, default_u32.size, true},
Params{ty_i32, default_i32.size, true},
Params{ty_f32, default_f32.size, true},
Params{ty_vec2<f32>, default_vec2.size, true},
Params{ast_u32, default_u32.size, true},
Params{ast_i32, default_i32.size, true},
Params{ast_f32, default_f32.size, true},
Params{ast_vec2<f32>, default_vec2.size, true},
// vec3's default size is not a multiple of its alignment
// Params{ty_vec3<f32>, default_vec3.size, true},
Params{ty_vec4<f32>, default_vec4.size, true},
Params{ty_mat2x2<f32>, default_mat2x2.size, true},
Params{ty_mat3x3<f32>, default_mat3x3.size, true},
Params{ty_mat4x4<f32>, default_mat4x4.size, true},
// Params{ast_vec3<f32>, default_vec3.size, true},
Params{ast_vec4<f32>, default_vec4.size, true},
Params{ast_mat2x2<f32>, default_mat2x2.size, true},
Params{ast_mat3x3<f32>, default_mat3x3.size, true},
Params{ast_mat4x4<f32>, default_mat4x4.size, true},
// Fail because stride is < element size
Params{ty_u32, default_u32.size - 1, false},
Params{ty_i32, default_i32.size - 1, false},
Params{ty_f32, default_f32.size - 1, false},
Params{ty_vec2<f32>, default_vec2.size - 1, false},
Params{ty_vec3<f32>, default_vec3.size - 1, false},
Params{ty_vec4<f32>, default_vec4.size - 1, false},
Params{ty_mat2x2<f32>, default_mat2x2.size - 1, false},
Params{ty_mat3x3<f32>, default_mat3x3.size - 1, false},
Params{ty_mat4x4<f32>, default_mat4x4.size - 1, false},
Params{ast_u32, default_u32.size - 1, false},
Params{ast_i32, default_i32.size - 1, false},
Params{ast_f32, default_f32.size - 1, false},
Params{ast_vec2<f32>, default_vec2.size - 1, false},
Params{ast_vec3<f32>, default_vec3.size - 1, false},
Params{ast_vec4<f32>, default_vec4.size - 1, false},
Params{ast_mat2x2<f32>, default_mat2x2.size - 1, false},
Params{ast_mat3x3<f32>, default_mat3x3.size - 1, false},
Params{ast_mat4x4<f32>, default_mat4x4.size - 1, false},
// Succeed because stride equals multiple of element alignment
Params{ty_u32, default_u32.align * 7, true},
Params{ty_i32, default_i32.align * 7, true},
Params{ty_f32, default_f32.align * 7, true},
Params{ty_vec2<f32>, default_vec2.align * 7, true},
Params{ty_vec3<f32>, default_vec3.align * 7, true},
Params{ty_vec4<f32>, default_vec4.align * 7, true},
Params{ty_mat2x2<f32>, default_mat2x2.align * 7, true},
Params{ty_mat3x3<f32>, default_mat3x3.align * 7, true},
Params{ty_mat4x4<f32>, default_mat4x4.align * 7, true},
Params{ast_u32, default_u32.align * 7, true},
Params{ast_i32, default_i32.align * 7, true},
Params{ast_f32, default_f32.align * 7, true},
Params{ast_vec2<f32>, default_vec2.align * 7, true},
Params{ast_vec3<f32>, default_vec3.align * 7, true},
Params{ast_vec4<f32>, default_vec4.align * 7, true},
Params{ast_mat2x2<f32>, default_mat2x2.align * 7, true},
Params{ast_mat3x3<f32>, default_mat3x3.align * 7, true},
Params{ast_mat4x4<f32>, default_mat4x4.align * 7, true},
// Fail because stride is not multiple of element alignment
Params{ty_u32, (default_u32.align - 1) * 7, false},
Params{ty_i32, (default_i32.align - 1) * 7, false},
Params{ty_f32, (default_f32.align - 1) * 7, false},
Params{ty_vec2<f32>, (default_vec2.align - 1) * 7, false},
Params{ty_vec3<f32>, (default_vec3.align - 1) * 7, false},
Params{ty_vec4<f32>, (default_vec4.align - 1) * 7, false},
Params{ty_mat2x2<f32>, (default_mat2x2.align - 1) * 7, false},
Params{ty_mat3x3<f32>, (default_mat3x3.align - 1) * 7, false},
Params{ty_mat4x4<f32>, (default_mat4x4.align - 1) * 7, false}));
Params{ast_u32, (default_u32.align - 1) * 7, false},
Params{ast_i32, (default_i32.align - 1) * 7, false},
Params{ast_f32, (default_f32.align - 1) * 7, false},
Params{ast_vec2<f32>, (default_vec2.align - 1) * 7, false},
Params{ast_vec3<f32>, (default_vec3.align - 1) * 7, false},
Params{ast_vec4<f32>, (default_vec4.align - 1) * 7, false},
Params{ast_mat2x2<f32>, (default_mat2x2.align - 1) * 7, false},
Params{ast_mat3x3<f32>, (default_mat3x3.align - 1) * 7, false},
Params{ast_mat4x4<f32>, (default_mat4x4.align - 1) * 7, false}));
TEST_F(ArrayStrideTest, MultipleDecorations) {
auto arr = ty.array(ty.i32(), 4,
auto arr = ty.array(Source{{12, 34}}, ty.i32(), 4,
{
create<ast::StrideDecoration>(4),
create<ast::StrideDecoration>(4),
});
Global(Source{{12, 34}}, "myarray", arr, ast::StorageClass::kInput);
Global("myarray", arr, ast::StorageClass::kInput);
EXPECT_FALSE(r()->Resolve());
EXPECT_EQ(r()->error(),

View File

@ -40,6 +40,7 @@
#include "src/ast/storage_texture.h"
#include "src/ast/struct_block_decoration.h"
#include "src/ast/switch_statement.h"
#include "src/ast/type_name.h"
#include "src/ast/unary_op_expression.h"
#include "src/ast/variable_decl_statement.h"
#include "src/ast/vector.h"
@ -151,8 +152,17 @@ void Resolver::set_referenced_from_function_if_needed(VariableInfo* var,
}
bool Resolver::Resolve() {
if (builder_->Diagnostics().contains_errors()) {
return false;
}
bool result = ResolveInternal();
if (result && diagnostics_.contains_errors()) {
TINT_ICE(diagnostics_) << "resolving failed, but no error was raised";
return false;
}
// Even if resolving failed, create all the semantic nodes for information we
// did generate.
CreateSemanticNodes();
@ -169,14 +179,16 @@ bool Resolver::IsStorable(const sem::Type* type) {
if (auto* arr = type->As<sem::ArrayType>()) {
return IsStorable(arr->type());
}
if (auto* str = type->As<sem::StructType>()) {
for (const auto* member : str->impl()->members()) {
if (!IsStorable(member->type())) {
if (auto* str_ty = type->As<sem::StructType>()) {
if (auto* str = Structure(str_ty)) {
for (const auto* member : str->members) {
if (!IsStorable(member->Type())) {
return false;
}
}
return true;
}
}
return false;
}
@ -196,8 +208,12 @@ bool Resolver::IsHostShareable(const sem::Type* type) {
return IsHostShareable(arr->type());
}
if (auto* str = type->As<sem::StructType>()) {
for (auto* member : str->impl()->members()) {
if (!IsHostShareable(member->type())) {
auto* info = Structure(str);
if (!info) {
return false;
}
for (auto* member : info->members) {
if (!IsHostShareable(member->Type())) {
return false;
}
}
@ -225,11 +241,28 @@ bool Resolver::IsValidAssignment(const sem::Type* lhs, const sem::Type* rhs) {
bool Resolver::ResolveInternal() {
Mark(&builder_->AST());
auto register_named_type = [this](Symbol name, const sem::Type* type,
const Source& source) {
auto added = named_types_.emplace(name, type).second;
if (!added) {
diagnostics_.add_error("type with the name '" +
builder_->Symbols().NameFor(name) +
"' was already declared",
source);
return false;
}
return true;
};
// Process everything else in the order they appear in the module. This is
// necessary for validation of use-before-declaration.
for (auto* decl : builder_->AST().GlobalDeclarations()) {
if (auto* ty = decl->As<sem::Type>()) {
if (!Type(ty)) {
if (auto* ty = decl->As<ast::NamedType>()) {
auto* sem_ty = Type(ty);
if (sem_ty == nullptr) {
return false;
}
if (!register_named_type(ty->name(), sem_ty, ty->source())) {
return false;
}
} else if (auto* func = decl->As<ast::Function>()) {
@ -249,6 +282,8 @@ bool Resolver::ResolveInternal() {
}
}
bool result = true;
for (auto* node : builder_->ASTNodes().Objects()) {
if (marked_.count(node) == 0) {
if (node->IsAnyOf<ast::AccessDecoration, ast::StrideDecoration,
@ -268,10 +303,11 @@ bool Resolver::ResolveInternal() {
<< "At: " << node->source() << "\n"
<< "Content: " << builder_->str(node) << "\n"
<< "Pointer: " << node;
result = false;
}
}
return true;
return result;
}
const sem::Type* Resolver::Type(const ast::Type* ty) {
@ -360,6 +396,16 @@ const sem::Type* Resolver::Type(const ast::Type* ty) {
}
return nullptr;
}
if (auto* t = ty->As<ast::TypeName>()) {
auto it = named_types_.find(t->name());
if (it == named_types_.end()) {
diagnostics_.add_error(
"unknown type '" + builder_->Symbols().NameFor(t->name()) + "'",
t->source());
return nullptr;
}
return it->second;
}
TINT_UNREACHABLE(diagnostics_)
<< "Unhandled ast::Type: " << ty->TypeInfo().name;
return nullptr;
@ -392,21 +438,26 @@ bool Resolver::Type(const sem::Type* ty, const Source& source /* = {} */) {
Resolver::VariableInfo* Resolver::Variable(
ast::Variable* var,
const sem::Type* type /* = nullptr*/) {
const sem::Type* type, /* = nullptr */
std::string type_name /* = "" */) {
auto it = variable_to_info_.find(var);
if (it != variable_to_info_.end()) {
return it->second;
}
if (!type) {
type = var->declared_type();
if (type == nullptr && var->type()) {
type = Type(var->type());
type_name = var->type()->FriendlyName(builder_->Symbols());
}
if (type == nullptr) {
return nullptr;
}
auto type_name = type->FriendlyName(builder_->Symbols());
auto* ctype = Canonical(type);
auto* info = variable_infos_.Create(var, ctype, type_name);
variable_to_info_.emplace(var, info);
// TODO(bclayton): Why is this here? Needed?
// Resolve variable's type
if (auto* arr = info->type->As<sem::ArrayType>()) {
if (!Array(arr, var->source())) {
@ -805,12 +856,12 @@ bool Resolver::ValidateEntryPoint(const ast::Function* func,
if (auto* struct_ty = Canonical(ty)->As<sem::StructType>()) {
// Validate the decorations for each struct members, and also check for
// invalid member types.
for (auto* member : struct_ty->impl()->members()) {
auto* member_ty = Canonical(member->type());
for (auto* member : Structure(struct_ty)->members) {
auto* member_ty = Canonical(member->Type());
if (member_ty->Is<sem::StructType>()) {
diagnostics_.add_error(
"entry point IO types cannot contain nested structures",
member->source());
member->Declaration()->source());
diagnostics_.add_note("while analysing entry point " +
builder_->Symbols().NameFor(func->symbol()),
func->source());
@ -819,7 +870,7 @@ bool Resolver::ValidateEntryPoint(const ast::Function* func,
if (arr->IsRuntimeArray()) {
diagnostics_.add_error(
"entry point IO types cannot contain runtime sized arrays",
member->source());
member->Declaration()->source());
diagnostics_.add_note(
"while analysing entry point " +
builder_->Symbols().NameFor(func->symbol()),
@ -828,9 +879,9 @@ bool Resolver::ValidateEntryPoint(const ast::Function* func,
}
}
if (!validate_entry_point_decorations_inner(member->decorations(),
member_ty, member->source(),
param_or_ret, true)) {
if (!validate_entry_point_decorations_inner(
member->Declaration()->decorations(), member_ty,
member->Declaration()->source(), param_or_ret, true)) {
diagnostics_.add_note("while analysing entry point " +
builder_->Symbols().NameFor(func->symbol()),
func->source());
@ -842,10 +893,10 @@ bool Resolver::ValidateEntryPoint(const ast::Function* func,
return true;
};
for (auto* param : func->params()) {
for (auto* param : info->parameters) {
if (!validate_entry_point_decorations(
param->decorations(), param->declared_type(), param->source(),
ParamOrRetType::kParameter)) {
param->declaration->decorations(), param->type,
param->declaration->source(), ParamOrRetType::kParameter)) {
return false;
}
}
@ -943,19 +994,18 @@ bool Resolver::Function(ast::Function* func) {
}
}
if (func->return_type().ast || func->return_type().sem) {
info->return_type = func->return_type();
if (!info->return_type) {
info->return_type = Type(func->return_type().ast);
}
if (auto* ty = func->return_type()) {
info->return_type = Type(ty);
info->return_type_name = ty->FriendlyName(builder_->Symbols());
if (!info->return_type) {
return false;
}
} else {
info->return_type = builder_->create<sem::Void>();
info->return_type_name =
info->return_type->FriendlyName(builder_->Symbols());
}
info->return_type_name = info->return_type->FriendlyName(builder_->Symbols());
info->return_type = Canonical(info->return_type);
if (auto* str = info->return_type->As<sem::StructType>()) {
@ -1374,17 +1424,16 @@ bool Resolver::Constructor(ast::ConstructorExpression* expr) {
SetType(expr, type_ctor->type());
const sem::Type* type = TypeOf(expr);
// Now that the argument types have been determined, make sure that they
// obey the constructor type rules laid out in
// https://gpuweb.github.io/gpuweb/wgsl.html#type-constructor-expr.
if (auto* vec_type = type_ctor->type()->As<sem::Vector>()) {
return ValidateVectorConstructor(type_ctor, vec_type,
type_ctor->values());
if (auto* vec_type = type->As<sem::Vector>()) {
return ValidateVectorConstructor(type_ctor, vec_type);
}
if (auto* mat_type = type_ctor->type()->As<sem::Matrix>()) {
auto mat_typename = TypeNameOf(type_ctor);
return ValidateMatrixConstructor(type_ctor, mat_type,
type_ctor->values());
if (auto* mat_type = type->As<sem::Matrix>()) {
return ValidateMatrixConstructor(type_ctor, mat_type);
}
// TODO(crbug.com/tint/634): Validate array constructor
} else if (auto* scalar_ctor = expr->As<ast::ScalarConstructorExpression>()) {
@ -1398,8 +1447,8 @@ bool Resolver::Constructor(ast::ConstructorExpression* expr) {
bool Resolver::ValidateVectorConstructor(
const ast::TypeConstructorExpression* ctor,
const sem::Vector* vec_type,
const ast::ExpressionList& values) {
const sem::Vector* vec_type) {
auto& values = ctor->values();
auto* elem_type = vec_type->type()->UnwrapAll();
size_t value_cardinality_sum = 0;
for (auto* value : values) {
@ -1467,8 +1516,8 @@ bool Resolver::ValidateVectorConstructor(
bool Resolver::ValidateMatrixConstructor(
const ast::TypeConstructorExpression* ctor,
const sem::Matrix* matrix_type,
const ast::ExpressionList& values) {
const sem::Matrix* matrix_type) {
auto& values = ctor->values();
// Zero Value expression
if (values.empty()) {
return true;
@ -1600,7 +1649,7 @@ bool Resolver::MemberAccessor(ast::MemberAccessorExpression* expr) {
const sem::StructMember* member = nullptr;
for (auto* m : str->members) {
if (m->Declaration()->symbol() == symbol) {
ret = m->Declaration()->type();
ret = m->Type();
member = m;
break;
}
@ -1961,7 +2010,16 @@ bool Resolver::VariableDeclStatement(const ast::VariableDeclStatement* stmt) {
ast::Variable* var = stmt->variable();
Mark(var);
const sem::Type* type = var->declared_type();
// If the variable has a declared type, resolve it.
std::string type_name;
const sem::Type* type = nullptr;
if (auto* ast_ty = var->type()) {
type_name = ast_ty->FriendlyName(builder_->Symbols());
type = Type(ast_ty);
if (!type) {
return false;
}
}
bool is_global = false;
if (variable_stack_.get(var->symbol(), nullptr, &is_global)) {
@ -1982,14 +2040,15 @@ bool Resolver::VariableDeclStatement(const ast::VariableDeclStatement* stmt) {
// If the variable has no type, infer it from the rhs
if (type == nullptr) {
type_name = TypeNameOf(ctor);
type = rhs_type->UnwrapPtrIfNeeded();
}
if (!IsValidAssignment(type, rhs_type)) {
diagnostics_.add_error(
"variable of type '" + type->FriendlyName(builder_->Symbols()) +
"variable of type '" + type_name +
"' cannot be initialized with a value of type '" +
rhs_type->FriendlyName(builder_->Symbols()) + "'",
TypeNameOf(ctor) + "'",
stmt->source());
return false;
}
@ -2000,7 +2059,7 @@ bool Resolver::VariableDeclStatement(const ast::VariableDeclStatement* stmt) {
Mark(deco);
}
auto* info = Variable(var, type);
auto* info = Variable(var, type, type_name);
if (!info) {
return false;
}
@ -2071,13 +2130,19 @@ const sem::Type* Resolver::TypeOf(const ast::Literal* lit) {
return nullptr;
}
void Resolver::SetType(ast::Expression* expr, const sem::Type* type) {
SetType(expr, type, type->FriendlyName(builder_->Symbols()));
void Resolver::SetType(ast::Expression* expr, typ::Type type) {
SetType(expr, type,
type.sem ? type.sem->FriendlyName(builder_->Symbols())
: type.ast->FriendlyName(builder_->Symbols()));
}
void Resolver::SetType(ast::Expression* expr,
const sem::Type* type,
typ::Type type,
const std::string& type_name) {
if (!type.sem) {
type.sem = Type(type.ast);
TINT_ASSERT(type.sem);
}
if (expr_info_.count(expr)) {
TINT_ICE(builder_->Diagnostics())
<< "SetType() called twice for the same expression";
@ -2195,7 +2260,7 @@ void Resolver::CreateSemanticNodes() const {
}
}
bool Resolver::DefaultAlignAndSize(sem::Type* ty,
bool Resolver::DefaultAlignAndSize(const sem::Type* ty,
uint32_t& align,
uint32_t& size,
const Source& source) {
@ -2363,24 +2428,24 @@ bool Resolver::ValidateArrayStrideDecoration(const ast::StrideDecoration* deco,
return true;
}
bool Resolver::ValidateStructure(const sem::StructType* st) {
for (auto* member : st->impl()->members()) {
if (auto* r = member->type()->UnwrapAll()->As<sem::ArrayType>()) {
bool Resolver::ValidateStructure(const StructInfo* st) {
for (auto* member : st->members) {
if (auto* r = member->Type()->UnwrapAll()->As<sem::ArrayType>()) {
if (r->IsRuntimeArray()) {
if (member != st->impl()->members().back()) {
if (member != st->members.back()) {
diagnostics_.add_error(
"v-0015",
"runtime arrays may only appear as the last member of a struct",
member->source());
member->Declaration()->source());
return false;
}
if (!st->IsBlockDecorated()) {
if (!st->type->impl()->IsBlockDecorated()) {
diagnostics_.add_error(
"v-0015",
"a struct containing a runtime-sized array "
"requires the [[block]] attribute: '" +
builder_->Symbols().NameFor(st->impl()->name()) + "'",
member->source());
builder_->Symbols().NameFor(st->type->impl()->name()) + "'",
member->Declaration()->source());
return false;
}
@ -2394,7 +2459,7 @@ bool Resolver::ValidateStructure(const sem::StructType* st) {
}
}
for (auto* deco : member->decorations()) {
for (auto* deco : member->Declaration()->decorations()) {
if (!(deco->Is<ast::BuiltinDecoration>() ||
deco->Is<ast::LocationDecoration>() ||
deco->Is<ast::StructMemberOffsetDecoration>() ||
@ -2407,7 +2472,7 @@ bool Resolver::ValidateStructure(const sem::StructType* st) {
}
}
for (auto* deco : st->impl()->decorations()) {
for (auto* deco : st->type->impl()->decorations()) {
if (!(deco->Is<ast::StructBlockDecoration>())) {
diagnostics_.add_error("decoration is not valid for struct declarations",
deco->source());
@ -2425,15 +2490,10 @@ Resolver::StructInfo* Resolver::Structure(const sem::StructType* str) {
return info_it->second;
}
Mark(str->impl());
for (auto* deco : str->impl()->decorations()) {
Mark(deco);
}
if (!ValidateStructure(str)) {
return nullptr;
}
sem::StructMemberList sem_members;
sem_members.reserve(str->impl()->members().size());
@ -2454,12 +2514,16 @@ Resolver::StructInfo* Resolver::Structure(const sem::StructType* str) {
for (auto* member : str->impl()->members()) {
Mark(member);
auto type = member->type();
// Resolve member type
auto* type = Type(member->type());
if (!type) {
return nullptr;
}
// First check the member type is legal
// Validate member type
if (!IsStorable(type)) {
builder_->Diagnostics().add_error(
std::string(type->FriendlyName(builder_->Symbols())) +
type->FriendlyName(builder_->Symbols()) +
" cannot be used as the type of a structure member");
return nullptr;
}
@ -2518,8 +2582,8 @@ Resolver::StructInfo* Resolver::Structure(const sem::StructType* str) {
offset = utils::RoundUp(align, offset);
auto* sem_member =
builder_->create<sem::StructMember>(member, type, offset, align, size);
auto* sem_member = builder_->create<sem::StructMember>(
member, const_cast<sem::Type*>(type), offset, align, size);
builder_->Sem().Add(member, sem_member);
sem_members.emplace_back(sem_member);
@ -2531,11 +2595,17 @@ Resolver::StructInfo* Resolver::Structure(const sem::StructType* str) {
struct_size = utils::RoundUp(struct_align, struct_size);
auto* info = struct_infos_.Create();
info->type = str;
info->members = std::move(sem_members);
info->align = struct_align;
info->size = struct_size;
info->size_no_padding = size_no_padding;
struct_info_.emplace(str, info);
if (!ValidateStructure(info)) {
return nullptr;
}
return info;
}
@ -2745,13 +2815,13 @@ bool Resolver::ApplyStorageClassUsageToType(ast::StorageClass sc,
return true; // Already applied
}
info->storage_class_usage.emplace(sc);
for (auto* member : str->impl()->members()) {
if (!ApplyStorageClassUsageToType(sc, member->type(), usage)) {
for (auto* member : info->members) {
if (!ApplyStorageClassUsageToType(sc, member->Type(), usage)) {
std::stringstream err;
err << "while analysing structure member "
<< str->FriendlyName(builder_->Symbols()) << "."
<< builder_->Symbols().NameFor(member->symbol());
diagnostics_.add_note(err.str(), member->source());
<< builder_->Symbols().NameFor(member->Declaration()->symbol());
diagnostics_.add_note(err.str(), member->Declaration()->source());
return false;
}
}
@ -2798,6 +2868,11 @@ const sem::Type* Resolver::Canonical(const sem::Type* type) {
using Type = sem::Type;
using Vector = sem::Vector;
if (!type) {
TINT_ICE(diagnostics_) << "Canonical() called with nullptr";
return nullptr;
}
std::function<const Type*(const Type*)> make_canonical;
make_canonical = [&](const Type* t) -> const sem::Type* {
// Unwrap alias sequence

View File

@ -73,11 +73,11 @@ class Resolver {
/// @param type the given type
/// @returns true if the given type is storable
static bool IsStorable(const sem::Type* type);
bool IsStorable(const sem::Type* type);
/// @param type the given type
/// @returns true if the given type is host-shareable
static bool IsHostShareable(const sem::Type* type);
bool IsHostShareable(const sem::Type* type);
/// @param lhs the assignment store type (non-pointer)
/// @param rhs the assignment source type (non-pointer or pointer with
@ -148,6 +148,7 @@ class Resolver {
StructInfo();
~StructInfo();
sem::StructType const* type = nullptr;
std::vector<const sem::StructMember*> members;
uint32_t align = 0;
uint32_t size = 0;
@ -253,16 +254,14 @@ class Resolver {
bool ValidateFunction(const ast::Function* func, const FunctionInfo* info);
bool ValidateGlobalVariable(const VariableInfo* var);
bool ValidateMatrixConstructor(const ast::TypeConstructorExpression* ctor,
const sem::Matrix* matrix_type,
const ast::ExpressionList& values);
const sem::Matrix* matrix_type);
bool ValidateParameter(const ast::Variable* param);
bool ValidateReturn(const ast::ReturnStatement* ret);
bool ValidateStructure(const sem::StructType* st);
bool ValidateStructure(const StructInfo* st);
bool ValidateSwitch(const ast::SwitchStatement* s);
bool ValidateVariable(const ast::Variable* param);
bool ValidateVectorConstructor(const ast::TypeConstructorExpression* ctor,
const sem::Vector* vec_type,
const ast::ExpressionList& values);
const sem::Vector* vec_type);
/// @returns the sem::Type for the ast::Type `ty`, building it if it
/// hasn't been constructed already. If an error is raised, nullptr is
@ -284,9 +283,12 @@ class Resolver {
/// @returns the VariableInfo for the variable `var`, building it if it hasn't
/// been constructed already. If an error is raised, nullptr is returned.
/// @param var the variable to create or return the `VariableInfo` for
/// @param type optional type of `var` to use instead of
/// `var->declared_type()`. For type inference.
VariableInfo* Variable(ast::Variable* var, const sem::Type* type = nullptr);
/// @param type optional type of `var` to use instead of `var->type()`.
/// @param type_name optional type name of `var` to use instead of
/// `var->type()->FriendlyName()`.
VariableInfo* Variable(ast::Variable* var,
const sem::Type* type = nullptr,
std::string type_name = "");
/// Records the storage class usage for the given type, and any transient
/// dependencies of the type. Validates that the type can be used for the
@ -304,7 +306,7 @@ class Resolver {
/// @param size the output default size in bytes for the type `ty`
/// @param source the Source of the variable declaration of type `ty`
/// @returns true on success, false on error
bool DefaultAlignAndSize(sem::Type* ty,
bool DefaultAlignAndSize(const sem::Type* ty,
uint32_t& align,
uint32_t& size,
const Source& source);
@ -325,7 +327,7 @@ class Resolver {
/// assigns this semantic node to the expression `expr`.
/// @param expr the expression
/// @param type the resolved type
void SetType(ast::Expression* expr, const sem::Type* type);
void SetType(ast::Expression* expr, typ::Type type);
/// Creates a sem::Expression node with the resolved type `type`, the declared
/// type name `type_name` and assigns this semantic node to the expression
@ -334,7 +336,7 @@ class Resolver {
/// @param type the resolved type
/// @param type_name the declared type name
void SetType(ast::Expression* expr,
const sem::Type* type,
typ::Type type,
const std::string& type_name);
/// Constructs a new BlockInfo with the given type and with #current_block_ as
@ -369,6 +371,7 @@ class Resolver {
std::unordered_map<const ast::Expression*, ExpressionInfo> expr_info_;
std::unordered_map<const sem::StructType*, StructInfo*> struct_info_;
std::unordered_map<const sem::Type*, const sem::Type*> type_to_canonical_;
std::unordered_map<Symbol, const sem::Type*> named_types_;
std::unordered_set<const ast::Node*> marked_;
FunctionInfo* current_function_ = nullptr;
sem::Statement* current_statement_ = nullptr;

View File

@ -1024,15 +1024,15 @@ namespace ExprBinaryTest {
struct Params {
ast::BinaryOp op;
create_type_func_ptr create_lhs_type;
create_type_func_ptr create_rhs_type;
create_type_func_ptr create_result_type;
create_ast_type_func_ptr create_lhs_type;
create_ast_type_func_ptr create_rhs_type;
create_sem_type_func_ptr create_result_type;
};
static constexpr create_type_func_ptr all_create_type_funcs[] = {
ty_bool_, ty_u32, ty_i32, ty_f32,
ty_vec3<bool>, ty_vec3<i32>, ty_vec3<u32>, ty_vec3<f32>,
ty_mat3x3<i32>, ty_mat3x3<u32>, ty_mat3x3<f32>};
static constexpr create_ast_type_func_ptr all_create_type_funcs[] = {
ast_bool, ast_u32, ast_i32, ast_f32,
ast_vec3<bool>, ast_vec3<i32>, ast_vec3<u32>, ast_vec3<f32>,
ast_mat3x3<i32>, ast_mat3x3<u32>, ast_mat3x3<f32>};
// A list of all valid test cases for 'lhs op rhs', except that for vecN and
// matNxN, we only test N=3.
@ -1041,156 +1041,163 @@ static constexpr Params all_valid_cases[] = {
// https://gpuweb.github.io/gpuweb/wgsl.html#logical-expr
// Binary logical expressions
Params{Op::kLogicalAnd, ty_bool_, ty_bool_, ty_bool_},
Params{Op::kLogicalOr, ty_bool_, ty_bool_, ty_bool_},
Params{Op::kLogicalAnd, ast_bool, ast_bool, sem_bool},
Params{Op::kLogicalOr, ast_bool, ast_bool, sem_bool},
Params{Op::kAnd, ty_bool_, ty_bool_, ty_bool_},
Params{Op::kOr, ty_bool_, ty_bool_, ty_bool_},
Params{Op::kAnd, ty_vec3<bool>, ty_vec3<bool>, ty_vec3<bool>},
Params{Op::kOr, ty_vec3<bool>, ty_vec3<bool>, ty_vec3<bool>},
Params{Op::kAnd, ast_bool, ast_bool, sem_bool},
Params{Op::kOr, ast_bool, ast_bool, sem_bool},
Params{Op::kAnd, ast_vec3<bool>, ast_vec3<bool>, sem_vec3<sem_bool>},
Params{Op::kOr, ast_vec3<bool>, ast_vec3<bool>, sem_vec3<sem_bool>},
// Arithmetic expressions
// https://gpuweb.github.io/gpuweb/wgsl.html#arithmetic-expr
// Binary arithmetic expressions over scalars
Params{Op::kAdd, ty_i32, ty_i32, ty_i32},
Params{Op::kSubtract, ty_i32, ty_i32, ty_i32},
Params{Op::kMultiply, ty_i32, ty_i32, ty_i32},
Params{Op::kDivide, ty_i32, ty_i32, ty_i32},
Params{Op::kModulo, ty_i32, ty_i32, ty_i32},
Params{Op::kAdd, ast_i32, ast_i32, sem_i32},
Params{Op::kSubtract, ast_i32, ast_i32, sem_i32},
Params{Op::kMultiply, ast_i32, ast_i32, sem_i32},
Params{Op::kDivide, ast_i32, ast_i32, sem_i32},
Params{Op::kModulo, ast_i32, ast_i32, sem_i32},
Params{Op::kAdd, ty_u32, ty_u32, ty_u32},
Params{Op::kSubtract, ty_u32, ty_u32, ty_u32},
Params{Op::kMultiply, ty_u32, ty_u32, ty_u32},
Params{Op::kDivide, ty_u32, ty_u32, ty_u32},
Params{Op::kModulo, ty_u32, ty_u32, ty_u32},
Params{Op::kAdd, ast_u32, ast_u32, sem_u32},
Params{Op::kSubtract, ast_u32, ast_u32, sem_u32},
Params{Op::kMultiply, ast_u32, ast_u32, sem_u32},
Params{Op::kDivide, ast_u32, ast_u32, sem_u32},
Params{Op::kModulo, ast_u32, ast_u32, sem_u32},
Params{Op::kAdd, ty_f32, ty_f32, ty_f32},
Params{Op::kSubtract, ty_f32, ty_f32, ty_f32},
Params{Op::kMultiply, ty_f32, ty_f32, ty_f32},
Params{Op::kDivide, ty_f32, ty_f32, ty_f32},
Params{Op::kModulo, ty_f32, ty_f32, ty_f32},
Params{Op::kAdd, ast_f32, ast_f32, sem_f32},
Params{Op::kSubtract, ast_f32, ast_f32, sem_f32},
Params{Op::kMultiply, ast_f32, ast_f32, sem_f32},
Params{Op::kDivide, ast_f32, ast_f32, sem_f32},
Params{Op::kModulo, ast_f32, ast_f32, sem_f32},
// Binary arithmetic expressions over vectors
Params{Op::kAdd, ty_vec3<i32>, ty_vec3<i32>, ty_vec3<i32>},
Params{Op::kSubtract, ty_vec3<i32>, ty_vec3<i32>, ty_vec3<i32>},
Params{Op::kMultiply, ty_vec3<i32>, ty_vec3<i32>, ty_vec3<i32>},
Params{Op::kDivide, ty_vec3<i32>, ty_vec3<i32>, ty_vec3<i32>},
Params{Op::kModulo, ty_vec3<i32>, ty_vec3<i32>, ty_vec3<i32>},
Params{Op::kAdd, ast_vec3<i32>, ast_vec3<i32>, sem_vec3<sem_i32>},
Params{Op::kSubtract, ast_vec3<i32>, ast_vec3<i32>, sem_vec3<sem_i32>},
Params{Op::kMultiply, ast_vec3<i32>, ast_vec3<i32>, sem_vec3<sem_i32>},
Params{Op::kDivide, ast_vec3<i32>, ast_vec3<i32>, sem_vec3<sem_i32>},
Params{Op::kModulo, ast_vec3<i32>, ast_vec3<i32>, sem_vec3<sem_i32>},
Params{Op::kAdd, ty_vec3<u32>, ty_vec3<u32>, ty_vec3<u32>},
Params{Op::kSubtract, ty_vec3<u32>, ty_vec3<u32>, ty_vec3<u32>},
Params{Op::kMultiply, ty_vec3<u32>, ty_vec3<u32>, ty_vec3<u32>},
Params{Op::kDivide, ty_vec3<u32>, ty_vec3<u32>, ty_vec3<u32>},
Params{Op::kModulo, ty_vec3<u32>, ty_vec3<u32>, ty_vec3<u32>},
Params{Op::kAdd, ast_vec3<u32>, ast_vec3<u32>, sem_vec3<sem_u32>},
Params{Op::kSubtract, ast_vec3<u32>, ast_vec3<u32>, sem_vec3<sem_u32>},
Params{Op::kMultiply, ast_vec3<u32>, ast_vec3<u32>, sem_vec3<sem_u32>},
Params{Op::kDivide, ast_vec3<u32>, ast_vec3<u32>, sem_vec3<sem_u32>},
Params{Op::kModulo, ast_vec3<u32>, ast_vec3<u32>, sem_vec3<sem_u32>},
Params{Op::kAdd, ty_vec3<f32>, ty_vec3<f32>, ty_vec3<f32>},
Params{Op::kSubtract, ty_vec3<f32>, ty_vec3<f32>, ty_vec3<f32>},
Params{Op::kMultiply, ty_vec3<f32>, ty_vec3<f32>, ty_vec3<f32>},
Params{Op::kDivide, ty_vec3<f32>, ty_vec3<f32>, ty_vec3<f32>},
Params{Op::kModulo, ty_vec3<f32>, ty_vec3<f32>, ty_vec3<f32>},
Params{Op::kAdd, ast_vec3<f32>, ast_vec3<f32>, sem_vec3<sem_f32>},
Params{Op::kSubtract, ast_vec3<f32>, ast_vec3<f32>, sem_vec3<sem_f32>},
Params{Op::kMultiply, ast_vec3<f32>, ast_vec3<f32>, sem_vec3<sem_f32>},
Params{Op::kDivide, ast_vec3<f32>, ast_vec3<f32>, sem_vec3<sem_f32>},
Params{Op::kModulo, ast_vec3<f32>, ast_vec3<f32>, sem_vec3<sem_f32>},
// Binary arithmetic expressions with mixed scalar, vector, and matrix
// operands
Params{Op::kMultiply, ty_vec3<f32>, ty_f32, ty_vec3<f32>},
Params{Op::kMultiply, ty_f32, ty_vec3<f32>, ty_vec3<f32>},
Params{Op::kMultiply, ast_vec3<f32>, ast_f32, sem_vec3<sem_f32>},
Params{Op::kMultiply, ast_f32, ast_vec3<f32>, sem_vec3<sem_f32>},
Params{Op::kMultiply, ty_mat3x3<f32>, ty_f32, ty_mat3x3<f32>},
Params{Op::kMultiply, ty_f32, ty_mat3x3<f32>, ty_mat3x3<f32>},
Params{Op::kMultiply, ast_mat3x3<f32>, ast_f32, sem_mat3x3<sem_f32>},
Params{Op::kMultiply, ast_f32, ast_mat3x3<f32>, sem_mat3x3<sem_f32>},
Params{Op::kMultiply, ty_vec3<f32>, ty_mat3x3<f32>, ty_vec3<f32>},
Params{Op::kMultiply, ty_mat3x3<f32>, ty_vec3<f32>, ty_vec3<f32>},
Params{Op::kMultiply, ty_mat3x3<f32>, ty_mat3x3<f32>, ty_mat3x3<f32>},
Params{Op::kMultiply, ast_vec3<f32>, ast_mat3x3<f32>, sem_vec3<sem_f32>},
Params{Op::kMultiply, ast_mat3x3<f32>, ast_vec3<f32>, sem_vec3<sem_f32>},
Params{Op::kMultiply, ast_mat3x3<f32>, ast_mat3x3<f32>,
sem_mat3x3<sem_f32>},
// Comparison expressions
// https://gpuweb.github.io/gpuweb/wgsl.html#comparison-expr
// Comparisons over scalars
Params{Op::kEqual, ty_bool_, ty_bool_, ty_bool_},
Params{Op::kNotEqual, ty_bool_, ty_bool_, ty_bool_},
Params{Op::kEqual, ast_bool, ast_bool, sem_bool},
Params{Op::kNotEqual, ast_bool, ast_bool, sem_bool},
Params{Op::kEqual, ty_i32, ty_i32, ty_bool_},
Params{Op::kNotEqual, ty_i32, ty_i32, ty_bool_},
Params{Op::kLessThan, ty_i32, ty_i32, ty_bool_},
Params{Op::kLessThanEqual, ty_i32, ty_i32, ty_bool_},
Params{Op::kGreaterThan, ty_i32, ty_i32, ty_bool_},
Params{Op::kGreaterThanEqual, ty_i32, ty_i32, ty_bool_},
Params{Op::kEqual, ast_i32, ast_i32, sem_bool},
Params{Op::kNotEqual, ast_i32, ast_i32, sem_bool},
Params{Op::kLessThan, ast_i32, ast_i32, sem_bool},
Params{Op::kLessThanEqual, ast_i32, ast_i32, sem_bool},
Params{Op::kGreaterThan, ast_i32, ast_i32, sem_bool},
Params{Op::kGreaterThanEqual, ast_i32, ast_i32, sem_bool},
Params{Op::kEqual, ty_u32, ty_u32, ty_bool_},
Params{Op::kNotEqual, ty_u32, ty_u32, ty_bool_},
Params{Op::kLessThan, ty_u32, ty_u32, ty_bool_},
Params{Op::kLessThanEqual, ty_u32, ty_u32, ty_bool_},
Params{Op::kGreaterThan, ty_u32, ty_u32, ty_bool_},
Params{Op::kGreaterThanEqual, ty_u32, ty_u32, ty_bool_},
Params{Op::kEqual, ast_u32, ast_u32, sem_bool},
Params{Op::kNotEqual, ast_u32, ast_u32, sem_bool},
Params{Op::kLessThan, ast_u32, ast_u32, sem_bool},
Params{Op::kLessThanEqual, ast_u32, ast_u32, sem_bool},
Params{Op::kGreaterThan, ast_u32, ast_u32, sem_bool},
Params{Op::kGreaterThanEqual, ast_u32, ast_u32, sem_bool},
Params{Op::kEqual, ty_f32, ty_f32, ty_bool_},
Params{Op::kNotEqual, ty_f32, ty_f32, ty_bool_},
Params{Op::kLessThan, ty_f32, ty_f32, ty_bool_},
Params{Op::kLessThanEqual, ty_f32, ty_f32, ty_bool_},
Params{Op::kGreaterThan, ty_f32, ty_f32, ty_bool_},
Params{Op::kGreaterThanEqual, ty_f32, ty_f32, ty_bool_},
Params{Op::kEqual, ast_f32, ast_f32, sem_bool},
Params{Op::kNotEqual, ast_f32, ast_f32, sem_bool},
Params{Op::kLessThan, ast_f32, ast_f32, sem_bool},
Params{Op::kLessThanEqual, ast_f32, ast_f32, sem_bool},
Params{Op::kGreaterThan, ast_f32, ast_f32, sem_bool},
Params{Op::kGreaterThanEqual, ast_f32, ast_f32, sem_bool},
// Comparisons over vectors
Params{Op::kEqual, ty_vec3<bool>, ty_vec3<bool>, ty_vec3<bool>},
Params{Op::kNotEqual, ty_vec3<bool>, ty_vec3<bool>, ty_vec3<bool>},
Params{Op::kEqual, ast_vec3<bool>, ast_vec3<bool>, sem_vec3<sem_bool>},
Params{Op::kNotEqual, ast_vec3<bool>, ast_vec3<bool>, sem_vec3<sem_bool>},
Params{Op::kEqual, ty_vec3<i32>, ty_vec3<i32>, ty_vec3<bool>},
Params{Op::kNotEqual, ty_vec3<i32>, ty_vec3<i32>, ty_vec3<bool>},
Params{Op::kLessThan, ty_vec3<i32>, ty_vec3<i32>, ty_vec3<bool>},
Params{Op::kLessThanEqual, ty_vec3<i32>, ty_vec3<i32>, ty_vec3<bool>},
Params{Op::kGreaterThan, ty_vec3<i32>, ty_vec3<i32>, ty_vec3<bool>},
Params{Op::kGreaterThanEqual, ty_vec3<i32>, ty_vec3<i32>, ty_vec3<bool>},
Params{Op::kEqual, ast_vec3<i32>, ast_vec3<i32>, sem_vec3<sem_bool>},
Params{Op::kNotEqual, ast_vec3<i32>, ast_vec3<i32>, sem_vec3<sem_bool>},
Params{Op::kLessThan, ast_vec3<i32>, ast_vec3<i32>, sem_vec3<sem_bool>},
Params{Op::kLessThanEqual, ast_vec3<i32>, ast_vec3<i32>,
sem_vec3<sem_bool>},
Params{Op::kGreaterThan, ast_vec3<i32>, ast_vec3<i32>, sem_vec3<sem_bool>},
Params{Op::kGreaterThanEqual, ast_vec3<i32>, ast_vec3<i32>,
sem_vec3<sem_bool>},
Params{Op::kEqual, ty_vec3<u32>, ty_vec3<u32>, ty_vec3<bool>},
Params{Op::kNotEqual, ty_vec3<u32>, ty_vec3<u32>, ty_vec3<bool>},
Params{Op::kLessThan, ty_vec3<u32>, ty_vec3<u32>, ty_vec3<bool>},
Params{Op::kLessThanEqual, ty_vec3<u32>, ty_vec3<u32>, ty_vec3<bool>},
Params{Op::kGreaterThan, ty_vec3<u32>, ty_vec3<u32>, ty_vec3<bool>},
Params{Op::kGreaterThanEqual, ty_vec3<u32>, ty_vec3<u32>, ty_vec3<bool>},
Params{Op::kEqual, ast_vec3<u32>, ast_vec3<u32>, sem_vec3<sem_bool>},
Params{Op::kNotEqual, ast_vec3<u32>, ast_vec3<u32>, sem_vec3<sem_bool>},
Params{Op::kLessThan, ast_vec3<u32>, ast_vec3<u32>, sem_vec3<sem_bool>},
Params{Op::kLessThanEqual, ast_vec3<u32>, ast_vec3<u32>,
sem_vec3<sem_bool>},
Params{Op::kGreaterThan, ast_vec3<u32>, ast_vec3<u32>, sem_vec3<sem_bool>},
Params{Op::kGreaterThanEqual, ast_vec3<u32>, ast_vec3<u32>,
sem_vec3<sem_bool>},
Params{Op::kEqual, ty_vec3<f32>, ty_vec3<f32>, ty_vec3<bool>},
Params{Op::kNotEqual, ty_vec3<f32>, ty_vec3<f32>, ty_vec3<bool>},
Params{Op::kLessThan, ty_vec3<f32>, ty_vec3<f32>, ty_vec3<bool>},
Params{Op::kLessThanEqual, ty_vec3<f32>, ty_vec3<f32>, ty_vec3<bool>},
Params{Op::kGreaterThan, ty_vec3<f32>, ty_vec3<f32>, ty_vec3<bool>},
Params{Op::kGreaterThanEqual, ty_vec3<f32>, ty_vec3<f32>, ty_vec3<bool>},
Params{Op::kEqual, ast_vec3<f32>, ast_vec3<f32>, sem_vec3<sem_bool>},
Params{Op::kNotEqual, ast_vec3<f32>, ast_vec3<f32>, sem_vec3<sem_bool>},
Params{Op::kLessThan, ast_vec3<f32>, ast_vec3<f32>, sem_vec3<sem_bool>},
Params{Op::kLessThanEqual, ast_vec3<f32>, ast_vec3<f32>,
sem_vec3<sem_bool>},
Params{Op::kGreaterThan, ast_vec3<f32>, ast_vec3<f32>, sem_vec3<sem_bool>},
Params{Op::kGreaterThanEqual, ast_vec3<f32>, ast_vec3<f32>,
sem_vec3<sem_bool>},
// Bit expressions
// https://gpuweb.github.io/gpuweb/wgsl.html#bit-expr
// Binary bitwise operations
Params{Op::kOr, ty_i32, ty_i32, ty_i32},
Params{Op::kAnd, ty_i32, ty_i32, ty_i32},
Params{Op::kXor, ty_i32, ty_i32, ty_i32},
Params{Op::kOr, ast_i32, ast_i32, sem_i32},
Params{Op::kAnd, ast_i32, ast_i32, sem_i32},
Params{Op::kXor, ast_i32, ast_i32, sem_i32},
Params{Op::kOr, ty_u32, ty_u32, ty_u32},
Params{Op::kAnd, ty_u32, ty_u32, ty_u32},
Params{Op::kXor, ty_u32, ty_u32, ty_u32},
Params{Op::kOr, ast_u32, ast_u32, sem_u32},
Params{Op::kAnd, ast_u32, ast_u32, sem_u32},
Params{Op::kXor, ast_u32, ast_u32, sem_u32},
// Bit shift expressions
Params{Op::kShiftLeft, ty_i32, ty_u32, ty_i32},
Params{Op::kShiftLeft, ty_vec3<i32>, ty_vec3<u32>, ty_vec3<i32>},
Params{Op::kShiftLeft, ast_i32, ast_u32, sem_i32},
Params{Op::kShiftLeft, ast_vec3<i32>, ast_vec3<u32>, sem_vec3<sem_i32>},
Params{Op::kShiftLeft, ty_u32, ty_u32, ty_u32},
Params{Op::kShiftLeft, ty_vec3<u32>, ty_vec3<u32>, ty_vec3<u32>},
Params{Op::kShiftLeft, ast_u32, ast_u32, sem_u32},
Params{Op::kShiftLeft, ast_vec3<u32>, ast_vec3<u32>, sem_vec3<sem_u32>},
Params{Op::kShiftRight, ty_i32, ty_u32, ty_i32},
Params{Op::kShiftRight, ty_vec3<i32>, ty_vec3<u32>, ty_vec3<i32>},
Params{Op::kShiftRight, ast_i32, ast_u32, sem_i32},
Params{Op::kShiftRight, ast_vec3<i32>, ast_vec3<u32>, sem_vec3<sem_i32>},
Params{Op::kShiftRight, ty_u32, ty_u32, ty_u32},
Params{Op::kShiftRight, ty_vec3<u32>, ty_vec3<u32>, ty_vec3<u32>}};
Params{Op::kShiftRight, ast_u32, ast_u32, sem_u32},
Params{Op::kShiftRight, ast_vec3<u32>, ast_vec3<u32>, sem_vec3<sem_u32>}};
using Expr_Binary_Test_Valid = ResolverTestWithParam<Params>;
TEST_P(Expr_Binary_Test_Valid, All) {
auto& params = GetParam();
auto lhs_type = params.create_lhs_type(ty);
auto rhs_type = params.create_rhs_type(ty);
auto result_type = params.create_result_type(ty);
auto* lhs_type = params.create_lhs_type(ty);
auto* rhs_type = params.create_rhs_type(ty);
auto* result_type = params.create_result_type(ty);
std::stringstream ss;
ss << lhs_type->FriendlyName(Symbols()) << " " << params.op << " "
<< rhs_type->FriendlyName(Symbols());
ss << FriendlyName(lhs_type) << " " << params.op << " "
<< FriendlyName(rhs_type);
SCOPED_TRACE(ss.str());
Global("lhs", lhs_type, ast::StorageClass::kInput);
@ -1215,27 +1222,28 @@ TEST_P(Expr_Binary_Test_WithAlias_Valid, All) {
const Params& params = std::get<0>(GetParam());
BinaryExprSide side = std::get<1>(GetParam());
auto lhs_type = params.create_lhs_type(ty);
auto rhs_type = params.create_rhs_type(ty);
auto* lhs_type = params.create_lhs_type(ty);
auto* rhs_type = params.create_rhs_type(ty);
std::stringstream ss;
ss << lhs_type->FriendlyName(Symbols()) << " " << params.op << " "
<< rhs_type->FriendlyName(Symbols());
ss << FriendlyName(lhs_type) << " " << params.op << " "
<< FriendlyName(rhs_type);
// For vectors and matrices, wrap the sub type in an alias
auto make_alias = [this](sem::Type* type) -> sem::Type* {
sem::Type* result;
if (auto* v = type->As<sem::Vector>()) {
result = create<sem::Vector>(
create<sem::Alias>(Symbols().New(), v->type()), v->size());
} else if (auto* m = type->As<sem::Matrix>()) {
result =
create<sem::Matrix>(create<sem::Alias>(Symbols().New(), m->type()),
m->rows(), m->columns());
} else {
result = create<sem::Alias>(Symbols().New(), type);
auto make_alias = [this](ast::Type* type) -> ast::Type* {
if (auto* v = type->As<ast::Vector>()) {
auto alias = ty.alias(Symbols().New(), v->type());
AST().AddConstructedType(alias);
return ty.vec(alias, v->size());
}
return result;
if (auto* m = type->As<ast::Matrix>()) {
auto alias = ty.alias(Symbols().New(), m->type());
AST().AddConstructedType(alias);
return ty.mat(alias, m->columns(), m->rows());
}
auto alias = ty.alias(Symbols().New(), type);
AST().AddConstructedType(alias);
return ty.type_name(alias.ast->name());
};
// Wrap in alias
@ -1246,8 +1254,8 @@ TEST_P(Expr_Binary_Test_WithAlias_Valid, All) {
rhs_type = make_alias(rhs_type);
}
ss << ", After aliasing: " << lhs_type->FriendlyName(Symbols()) << " "
<< params.op << " " << rhs_type->FriendlyName(Symbols());
ss << ", After aliasing: " << FriendlyName(lhs_type) << " " << params.op
<< " " << FriendlyName(rhs_type);
SCOPED_TRACE(ss.str());
Global("lhs", lhs_type, ast::StorageClass::kInput);
@ -1261,7 +1269,7 @@ TEST_P(Expr_Binary_Test_WithAlias_Valid, All) {
ASSERT_NE(TypeOf(expr), nullptr);
// TODO(amaiorano): Bring this back once we have a way to get the canonical
// type
// auto* result_type = params.create_result_type(ty);
// auto* *result_type = params.create_result_type(ty);
// ASSERT_TRUE(TypeOf(expr) == result_type);
}
INSTANTIATE_TEST_SUITE_P(
@ -1273,10 +1281,10 @@ INSTANTIATE_TEST_SUITE_P(
BinaryExprSide::Both)));
using Expr_Binary_Test_Invalid =
ResolverTestWithParam<std::tuple<Params, create_type_func_ptr>>;
ResolverTestWithParam<std::tuple<Params, create_ast_type_func_ptr>>;
TEST_P(Expr_Binary_Test_Invalid, All) {
const Params& params = std::get<0>(GetParam());
const create_type_func_ptr& create_type_func = std::get<1>(GetParam());
auto& create_type_func = std::get<1>(GetParam());
// Currently, for most operations, for a given lhs type, there is exactly one
// rhs type allowed. The only exception is for multiplication, which allows
@ -1290,8 +1298,8 @@ TEST_P(Expr_Binary_Test_Invalid, All) {
return;
}
auto lhs_type = params.create_lhs_type(ty);
auto rhs_type = create_type_func(ty);
auto* lhs_type = params.create_lhs_type(ty);
auto* rhs_type = create_type_func(ty);
// Skip exceptions: multiplication of f32, vecN<f32>, and matNxN<f32>
if (params.op == Op::kMultiply &&
@ -1301,8 +1309,8 @@ TEST_P(Expr_Binary_Test_Invalid, All) {
}
std::stringstream ss;
ss << lhs_type->FriendlyName(Symbols()) << " " << params.op << " "
<< rhs_type->FriendlyName(Symbols());
ss << FriendlyName(lhs_type) << " " << params.op << " "
<< FriendlyName(rhs_type);
SCOPED_TRACE(ss.str());
Global("lhs", lhs_type, ast::StorageClass::kInput);
@ -1316,9 +1324,8 @@ TEST_P(Expr_Binary_Test_Invalid, All) {
ASSERT_EQ(r()->error(),
"12:34 error: Binary expression operand types are invalid for "
"this operation: " +
lhs_type->FriendlyName(Symbols()) + " " +
FriendlyName(expr->op()) + " " +
rhs_type->FriendlyName(Symbols()));
FriendlyName(lhs_type) + " " + ast::FriendlyName(expr->op()) +
" " + FriendlyName(rhs_type));
}
INSTANTIATE_TEST_SUITE_P(
ResolverTest,
@ -1365,9 +1372,8 @@ TEST_P(Expr_Binary_Test_Invalid_VectorMatrixMultiply, All) {
ASSERT_EQ(r()->error(),
"12:34 error: Binary expression operand types are invalid for "
"this operation: " +
lhs_type->FriendlyName(Symbols()) + " " +
FriendlyName(expr->op()) + " " +
rhs_type->FriendlyName(Symbols()));
FriendlyName(lhs_type) + " " + ast::FriendlyName(expr->op()) +
" " + FriendlyName(rhs_type));
}
}
auto all_dimension_values = testing::Values(2u, 3u, 4u);
@ -1405,9 +1411,8 @@ TEST_P(Expr_Binary_Test_Invalid_MatrixMatrixMultiply, All) {
ASSERT_EQ(r()->error(),
"12:34 error: Binary expression operand types are invalid for "
"this operation: " +
lhs_type->FriendlyName(Symbols()) + " " +
FriendlyName(expr->op()) + " " +
rhs_type->FriendlyName(Symbols()));
FriendlyName(lhs_type) + " " + ast::FriendlyName(expr->op()) +
" " + FriendlyName(rhs_type));
}
}
INSTANTIATE_TEST_SUITE_P(ResolverTest,

View File

@ -16,6 +16,7 @@
#define SRC_RESOLVER_RESOLVER_TEST_HELPER_H_
#include <memory>
#include <string>
#include <vector>
#include "gtest/gtest.h"
@ -95,6 +96,14 @@ class TestHelper : public ProgramBuilder {
return true;
}
/// @param type a type
/// @returns the name for `type` that closely resembles how it would be
/// declared in WGSL.
std::string FriendlyName(typ::Type type) {
return type.ast ? type.ast->FriendlyName(Symbols())
: type.sem->FriendlyName(Symbols());
}
private:
std::unique_ptr<Resolver> resolver_;
};
@ -105,94 +114,151 @@ template <typename T>
class ResolverTestWithParam : public TestHelper,
public testing::TestWithParam<T> {};
inline typ::Type ty_bool_(const ProgramBuilder::TypesBuilder& ty) {
inline ast::Type* ast_bool(const ProgramBuilder::TypesBuilder& ty) {
return ty.bool_();
}
inline typ::Type ty_i32(const ProgramBuilder::TypesBuilder& ty) {
inline ast::Type* ast_i32(const ProgramBuilder::TypesBuilder& ty) {
return ty.i32();
}
inline typ::Type ty_u32(const ProgramBuilder::TypesBuilder& ty) {
inline ast::Type* ast_u32(const ProgramBuilder::TypesBuilder& ty) {
return ty.u32();
}
inline typ::Type ty_f32(const ProgramBuilder::TypesBuilder& ty) {
inline ast::Type* ast_f32(const ProgramBuilder::TypesBuilder& ty) {
return ty.f32();
}
using create_type_func_ptr =
typ::Type (*)(const ProgramBuilder::TypesBuilder& ty);
using create_ast_type_func_ptr =
ast::Type* (*)(const ProgramBuilder::TypesBuilder& ty);
template <typename T>
typ::Type ty_vec2(const ProgramBuilder::TypesBuilder& ty) {
ast::Type* ast_vec2(const ProgramBuilder::TypesBuilder& ty) {
return ty.vec2<T>();
}
template <create_type_func_ptr create_type>
typ::Type ty_vec2(const ProgramBuilder::TypesBuilder& ty) {
template <create_ast_type_func_ptr create_type>
ast::Type* ast_vec2(const ProgramBuilder::TypesBuilder& ty) {
return ty.vec2(create_type(ty));
}
template <typename T>
typ::Type ty_vec3(const ProgramBuilder::TypesBuilder& ty) {
ast::Type* ast_vec3(const ProgramBuilder::TypesBuilder& ty) {
return ty.vec3<T>();
}
template <create_type_func_ptr create_type>
typ::Type ty_vec3(const ProgramBuilder::TypesBuilder& ty) {
template <create_ast_type_func_ptr create_type>
ast::Type* ast_vec3(const ProgramBuilder::TypesBuilder& ty) {
return ty.vec3(create_type(ty));
}
template <typename T>
typ::Type ty_vec4(const ProgramBuilder::TypesBuilder& ty) {
ast::Type* ast_vec4(const ProgramBuilder::TypesBuilder& ty) {
return ty.vec4<T>();
}
template <create_type_func_ptr create_type>
typ::Type ty_vec4(const ProgramBuilder::TypesBuilder& ty) {
template <create_ast_type_func_ptr create_type>
ast::Type* ast_vec4(const ProgramBuilder::TypesBuilder& ty) {
return ty.vec4(create_type(ty));
}
template <typename T>
typ::Type ty_mat2x2(const ProgramBuilder::TypesBuilder& ty) {
ast::Type* ast_mat2x2(const ProgramBuilder::TypesBuilder& ty) {
return ty.mat2x2<T>();
}
template <create_type_func_ptr create_type>
typ::Type ty_mat2x2(const ProgramBuilder::TypesBuilder& ty) {
template <create_ast_type_func_ptr create_type>
ast::Type* ast_mat2x2(const ProgramBuilder::TypesBuilder& ty) {
return ty.mat2x2(create_type(ty));
}
template <typename T>
typ::Type ty_mat3x3(const ProgramBuilder::TypesBuilder& ty) {
ast::Type* ast_mat3x3(const ProgramBuilder::TypesBuilder& ty) {
return ty.mat3x3<T>();
}
template <create_type_func_ptr create_type>
typ::Type ty_mat3x3(const ProgramBuilder::TypesBuilder& ty) {
template <create_ast_type_func_ptr create_type>
ast::Type* ast_mat3x3(const ProgramBuilder::TypesBuilder& ty) {
return ty.mat3x3(create_type(ty));
}
template <typename T>
typ::Type ty_mat4x4(const ProgramBuilder::TypesBuilder& ty) {
ast::Type* ast_mat4x4(const ProgramBuilder::TypesBuilder& ty) {
return ty.mat4x4<T>();
}
template <create_type_func_ptr create_type>
typ::Type ty_mat4x4(const ProgramBuilder::TypesBuilder& ty) {
template <create_ast_type_func_ptr create_type>
ast::Type* ast_mat4x4(const ProgramBuilder::TypesBuilder& ty) {
return ty.mat4x4(create_type(ty));
}
template <create_type_func_ptr create_type>
typ::Type ty_alias(const ProgramBuilder::TypesBuilder& ty) {
auto type = create_type(ty);
return ty.alias("alias_" + type->type_name(), type);
template <create_ast_type_func_ptr create_type>
ast::Type* ast_alias(const ProgramBuilder::TypesBuilder& ty) {
auto* type = create_type(ty);
auto name = ty.builder->Symbols().Register("alias_" + type->type_name());
if (!ty.builder->AST().LookupType(name)) {
ty.builder->AST().AddConstructedType(ty.alias(name, type));
}
return ty.builder->create<ast::TypeName>(name);
}
template <create_type_func_ptr create_type>
typ::Type ty_access(const ProgramBuilder::TypesBuilder& ty) {
auto type = create_type(ty);
template <create_ast_type_func_ptr create_type>
ast::Type* ast_access(const ProgramBuilder::TypesBuilder& ty) {
auto* type = create_type(ty);
return ty.access(ast::AccessControl::kReadOnly, type);
}
inline sem::Type* sem_bool(const ProgramBuilder::TypesBuilder& ty) {
return ty.builder->create<sem::Bool>();
}
inline sem::Type* sem_i32(const ProgramBuilder::TypesBuilder& ty) {
return ty.builder->create<sem::I32>();
}
inline sem::Type* sem_u32(const ProgramBuilder::TypesBuilder& ty) {
return ty.builder->create<sem::U32>();
}
inline sem::Type* sem_f32(const ProgramBuilder::TypesBuilder& ty) {
return ty.builder->create<sem::F32>();
}
using create_sem_type_func_ptr =
sem::Type* (*)(const ProgramBuilder::TypesBuilder& ty);
template <create_sem_type_func_ptr create_type>
sem::Type* sem_vec2(const ProgramBuilder::TypesBuilder& ty) {
return ty.builder->create<sem::Vector>(create_type(ty), 2);
}
template <create_sem_type_func_ptr create_type>
sem::Type* sem_vec3(const ProgramBuilder::TypesBuilder& ty) {
return ty.builder->create<sem::Vector>(create_type(ty), 3);
}
template <create_sem_type_func_ptr create_type>
sem::Type* sem_vec4(const ProgramBuilder::TypesBuilder& ty) {
return ty.builder->create<sem::Vector>(create_type(ty), 4);
}
template <create_sem_type_func_ptr create_type>
sem::Type* sem_mat2x2(const ProgramBuilder::TypesBuilder& ty) {
return ty.builder->create<sem::Matrix>(create_type(ty), 2, 2);
}
template <create_sem_type_func_ptr create_type>
sem::Type* sem_mat3x3(const ProgramBuilder::TypesBuilder& ty) {
return ty.builder->create<sem::Matrix>(create_type(ty), 3, 3);
}
template <create_sem_type_func_ptr create_type>
sem::Type* sem_mat4x4(const ProgramBuilder::TypesBuilder& ty) {
return ty.builder->create<sem::Matrix>(create_type(ty), 4, 4);
}
template <create_sem_type_func_ptr create_type>
sem::Type* sem_access(const ProgramBuilder::TypesBuilder& ty) {
auto* type = create_type(ty);
return ty.builder->create<sem::AccessControl>(ast::AccessControl::kReadOnly,
type);
}
} // namespace resolver
} // namespace tint

View File

@ -19,11 +19,11 @@ namespace resolver {
namespace {
/// @return the element type of `type` for vec and mat, otherwise `type` itself
sem::Type* ElementTypeOf(sem::Type* type) {
if (auto* v = type->As<sem::Vector>()) {
ast::Type* ElementTypeOf(ast::Type* type) {
if (auto* v = type->As<ast::Vector>()) {
return v->type();
}
if (auto* m = type->As<sem::Matrix>()) {
if (auto* m = type->As<ast::Matrix>()) {
return m->type();
}
return type;
@ -34,7 +34,8 @@ class ResolverTypeConstructorValidationTest : public resolver::TestHelper,
namespace InferTypeTest {
struct Params {
create_type_func_ptr create_rhs_type;
create_ast_type_func_ptr create_rhs_ast_type;
create_sem_type_func_ptr create_rhs_sem_type;
};
// Helpers and typedefs
@ -66,7 +67,7 @@ TEST_P(InferTypeTest_FromConstructorExpression, All) {
// }
auto& params = GetParam();
auto rhs_type = params.create_rhs_type(ty);
auto* rhs_type = params.create_rhs_ast_type(ty);
auto* constructor_expr = ConstructValueFilledWith(rhs_type, 0);
auto sc = ast::StorageClass::kFunction;
@ -77,30 +78,33 @@ TEST_P(InferTypeTest_FromConstructorExpression, All) {
WrapInFunction(Decl(a), Assign(a_ident, "a"));
ASSERT_TRUE(r()->Resolve()) << r()->error();
ASSERT_EQ(TypeOf(a_ident), ty.pointer(rhs_type->UnwrapAliasIfNeeded(), sc));
auto* got = TypeOf(a_ident);
auto* expected = ty.pointer(params.create_rhs_sem_type(ty), sc).sem;
ASSERT_EQ(got, expected) << "got: " << FriendlyName(got) << "\n"
<< "expected: " << FriendlyName(expected) << "\n";
}
static constexpr Params from_constructor_expression_cases[] = {
Params{ty_bool_},
Params{ty_i32},
Params{ty_u32},
Params{ty_f32},
Params{ty_vec3<i32>},
Params{ty_vec3<u32>},
Params{ty_vec3<f32>},
Params{ty_mat3x3<i32>},
Params{ty_mat3x3<u32>},
Params{ty_mat3x3<f32>},
Params{ty_alias<ty_bool_>},
Params{ty_alias<ty_i32>},
Params{ty_alias<ty_u32>},
Params{ty_alias<ty_f32>},
Params{ty_alias<ty_vec3<i32>>},
Params{ty_alias<ty_vec3<u32>>},
Params{ty_alias<ty_vec3<f32>>},
Params{ty_alias<ty_mat3x3<i32>>},
Params{ty_alias<ty_mat3x3<u32>>},
Params{ty_alias<ty_mat3x3<f32>>},
Params{ast_bool, sem_bool},
Params{ast_i32, sem_i32},
Params{ast_u32, sem_u32},
Params{ast_f32, sem_f32},
Params{ast_vec3<i32>, sem_vec3<sem_i32>},
Params{ast_vec3<u32>, sem_vec3<sem_u32>},
Params{ast_vec3<f32>, sem_vec3<sem_f32>},
Params{ast_mat3x3<i32>, sem_mat3x3<sem_i32>},
Params{ast_mat3x3<u32>, sem_mat3x3<sem_u32>},
Params{ast_mat3x3<f32>, sem_mat3x3<sem_f32>},
Params{ast_alias<ast_bool>, sem_bool},
Params{ast_alias<ast_i32>, sem_i32},
Params{ast_alias<ast_u32>, sem_u32},
Params{ast_alias<ast_f32>, sem_f32},
Params{ast_alias<ast_vec3<i32>>, sem_vec3<sem_i32>},
Params{ast_alias<ast_vec3<u32>>, sem_vec3<sem_u32>},
Params{ast_alias<ast_vec3<f32>>, sem_vec3<sem_f32>},
Params{ast_alias<ast_mat3x3<i32>>, sem_mat3x3<sem_i32>},
Params{ast_alias<ast_mat3x3<u32>>, sem_mat3x3<sem_u32>},
Params{ast_alias<ast_mat3x3<f32>>, sem_mat3x3<sem_f32>},
};
INSTANTIATE_TEST_SUITE_P(ResolverTypeConstructorValidationTest,
InferTypeTest_FromConstructorExpression,
@ -114,7 +118,7 @@ TEST_P(InferTypeTest_FromArithmeticExpression, All) {
// }
auto& params = GetParam();
auto rhs_type = params.create_rhs_type(ty);
auto* rhs_type = params.create_rhs_ast_type(ty);
auto* arith_lhs_expr = ConstructValueFilledWith(rhs_type, 2);
auto* arith_rhs_expr = ConstructValueFilledWith(ElementTypeOf(rhs_type), 3);
@ -128,11 +132,17 @@ TEST_P(InferTypeTest_FromArithmeticExpression, All) {
WrapInFunction(Decl(a), Assign(a_ident, "a"));
ASSERT_TRUE(r()->Resolve()) << r()->error();
ASSERT_EQ(TypeOf(a_ident), ty.pointer(rhs_type, sc));
auto* got = TypeOf(a_ident);
auto* expected = ty.pointer(params.create_rhs_sem_type(ty), sc).sem;
ASSERT_EQ(got, expected) << "got: " << FriendlyName(got) << "\n"
<< "expected: " << FriendlyName(expected) << "\n";
}
static constexpr Params from_arithmetic_expression_cases[] = {
Params{ty_i32}, Params{ty_u32}, Params{ty_f32},
Params{ty_vec3<f32>}, Params{ty_mat3x3<f32>},
Params{ast_i32, sem_i32},
Params{ast_u32, sem_u32},
Params{ast_f32, sem_f32},
Params{ast_vec3<f32>, sem_vec3<sem_f32>},
Params{ast_mat3x3<f32>, sem_mat3x3<sem_f32>},
// TODO(amaiorano): Uncomment once https://crbug.com/tint/680 is fixed
// Params{ty_alias<ty_i32>},
@ -159,43 +169,44 @@ TEST_P(InferTypeTest_FromCallExpression, All) {
// }
auto& params = GetParam();
auto rhs_type = params.create_rhs_type(ty);
Func("foo", {}, rhs_type, {Return(ConstructValueFilledWith(rhs_type, 0))},
Func("foo", {}, params.create_rhs_ast_type(ty),
{Return(ConstructValueFilledWith(params.create_rhs_ast_type(ty), 0))},
{});
auto* constructor_expr = Call(Expr("foo"));
auto sc = ast::StorageClass::kFunction;
auto* a = Var("a", nullptr, sc, constructor_expr);
auto* a = Var("a", nullptr, sc, Call(Expr("foo")));
// Self-assign 'a' to force the expression to be resolved so we can test its
// type below
auto* a_ident = Expr("a");
WrapInFunction(Decl(a), Assign(a_ident, "a"));
ASSERT_TRUE(r()->Resolve()) << r()->error();
ASSERT_EQ(TypeOf(a_ident), ty.pointer(rhs_type->UnwrapAliasIfNeeded(), sc));
auto* got = TypeOf(a_ident);
auto* expected = ty.pointer(params.create_rhs_sem_type(ty), sc).sem;
ASSERT_EQ(got, expected) << "got: " << FriendlyName(got) << "\n"
<< "expected: " << FriendlyName(expected) << "\n";
}
static constexpr Params from_call_expression_cases[] = {
Params{ty_bool_},
Params{ty_i32},
Params{ty_u32},
Params{ty_f32},
Params{ty_vec3<i32>},
Params{ty_vec3<u32>},
Params{ty_vec3<f32>},
Params{ty_mat3x3<i32>},
Params{ty_mat3x3<u32>},
Params{ty_mat3x3<f32>},
Params{ty_alias<ty_bool_>},
Params{ty_alias<ty_i32>},
Params{ty_alias<ty_u32>},
Params{ty_alias<ty_f32>},
Params{ty_alias<ty_vec3<i32>>},
Params{ty_alias<ty_vec3<u32>>},
Params{ty_alias<ty_vec3<f32>>},
Params{ty_alias<ty_mat3x3<i32>>},
Params{ty_alias<ty_mat3x3<u32>>},
Params{ty_alias<ty_mat3x3<f32>>},
Params{ast_bool, sem_bool},
Params{ast_i32, sem_i32},
Params{ast_u32, sem_u32},
Params{ast_f32, sem_f32},
Params{ast_vec3<i32>, sem_vec3<sem_i32>},
Params{ast_vec3<u32>, sem_vec3<sem_u32>},
Params{ast_vec3<f32>, sem_vec3<sem_f32>},
Params{ast_mat3x3<i32>, sem_mat3x3<sem_i32>},
Params{ast_mat3x3<u32>, sem_mat3x3<sem_u32>},
Params{ast_mat3x3<f32>, sem_mat3x3<sem_f32>},
Params{ast_alias<ast_bool>, sem_bool},
Params{ast_alias<ast_i32>, sem_i32},
Params{ast_alias<ast_u32>, sem_u32},
Params{ast_alias<ast_f32>, sem_f32},
Params{ast_alias<ast_vec3<i32>>, sem_vec3<sem_i32>},
Params{ast_alias<ast_vec3<u32>>, sem_vec3<sem_u32>},
Params{ast_alias<ast_vec3<f32>>, sem_vec3<sem_f32>},
Params{ast_alias<ast_mat3x3<i32>>, sem_mat3x3<sem_i32>},
Params{ast_alias<ast_mat3x3<u32>>, sem_mat3x3<sem_u32>},
Params{ast_alias<ast_mat3x3<f32>>, sem_mat3x3<sem_f32>},
};
INSTANTIATE_TEST_SUITE_P(ResolverTypeConstructorValidationTest,
InferTypeTest_FromCallExpression,

View File

@ -445,48 +445,57 @@ TEST_F(ResolverTypeValidationTest, AliasRuntimeArrayIsLast_Pass) {
namespace GetCanonicalTests {
struct Params {
create_type_func_ptr create_type;
create_type_func_ptr create_canonical_type;
create_ast_type_func_ptr create_ast_type;
create_sem_type_func_ptr create_sem_type;
};
static constexpr Params cases[] = {
Params{ty_bool_, ty_bool_},
Params{ty_alias<ty_bool_>, ty_bool_},
Params{ty_alias<ty_alias<ty_bool_>>, ty_bool_},
Params{ast_bool, sem_bool},
Params{ast_alias<ast_bool>, sem_bool},
Params{ast_alias<ast_alias<ast_bool>>, sem_bool},
Params{ty_vec3<ty_f32>, ty_vec3<ty_f32>},
Params{ty_alias<ty_vec3<ty_f32>>, ty_vec3<ty_f32>},
Params{ty_alias<ty_alias<ty_vec3<ty_f32>>>, ty_vec3<ty_f32>},
Params{ast_vec3<ast_f32>, sem_vec3<sem_f32>},
Params{ast_alias<ast_vec3<ast_f32>>, sem_vec3<sem_f32>},
Params{ast_alias<ast_alias<ast_vec3<ast_f32>>>, sem_vec3<sem_f32>},
Params{ty_vec3<ty_alias<ty_f32>>, ty_vec3<ty_f32>},
Params{ty_alias<ty_vec3<ty_alias<ty_f32>>>, ty_vec3<ty_f32>},
Params{ty_alias<ty_alias<ty_vec3<ty_alias<ty_f32>>>>, ty_vec3<ty_f32>},
Params{ty_alias<ty_alias<ty_vec3<ty_alias<ty_alias<ty_f32>>>>>,
ty_vec3<ty_f32>},
Params{ast_vec3<ast_alias<ast_f32>>, sem_vec3<sem_f32>},
Params{ast_alias<ast_vec3<ast_alias<ast_f32>>>, sem_vec3<sem_f32>},
Params{ast_alias<ast_alias<ast_vec3<ast_alias<ast_f32>>>>,
sem_vec3<sem_f32>},
Params{ast_alias<ast_alias<ast_vec3<ast_alias<ast_alias<ast_f32>>>>>,
sem_vec3<sem_f32>},
Params{ty_mat3x3<ty_alias<ty_f32>>, ty_mat3x3<ty_f32>},
Params{ty_alias<ty_mat3x3<ty_alias<ty_f32>>>, ty_mat3x3<ty_f32>},
Params{ty_alias<ty_alias<ty_mat3x3<ty_alias<ty_f32>>>>, ty_mat3x3<ty_f32>},
Params{ty_alias<ty_alias<ty_mat3x3<ty_alias<ty_alias<ty_f32>>>>>,
ty_mat3x3<ty_f32>},
Params{ast_mat3x3<ast_alias<ast_f32>>, sem_mat3x3<sem_f32>},
Params{ast_alias<ast_mat3x3<ast_alias<ast_f32>>>, sem_mat3x3<sem_f32>},
Params{ast_alias<ast_alias<ast_mat3x3<ast_alias<ast_f32>>>>,
sem_mat3x3<sem_f32>},
Params{ast_alias<ast_alias<ast_mat3x3<ast_alias<ast_alias<ast_f32>>>>>,
sem_mat3x3<sem_f32>},
Params{ty_alias<ty_access<ty_alias<ty_bool_>>>, ty_access<ty_bool_>},
Params{ty_alias<ty_access<ty_alias<ty_vec3<ty_access<ty_f32>>>>>,
ty_access<ty_vec3<ty_access<ty_f32>>>},
Params{ty_alias<ty_access<ty_alias<ty_mat3x3<ty_access<ty_f32>>>>>,
ty_access<ty_mat3x3<ty_access<ty_f32>>>},
Params{ast_alias<ast_access<ast_alias<ast_bool>>>, sem_access<sem_bool>},
Params{ast_alias<ast_access<ast_alias<ast_vec3<ast_access<ast_f32>>>>>,
sem_access<sem_vec3<sem_access<sem_f32>>>},
Params{ast_alias<ast_access<ast_alias<ast_mat3x3<ast_access<ast_f32>>>>>,
sem_access<sem_mat3x3<sem_access<sem_f32>>>},
};
using CanonicalTest = ResolverTestWithParam<Params>;
TEST_P(CanonicalTest, All) {
auto& params = GetParam();
auto type = params.create_type(ty);
auto expected_canonical_type = params.create_canonical_type(ty);
auto* type = params.create_ast_type(ty);
auto* canonical_type = r()->Canonical(type);
auto* var = Var("v", type, ast::StorageClass::kFunction);
auto* expr = Expr("v");
WrapInFunction(var, expr);
EXPECT_EQ(canonical_type, expected_canonical_type);
EXPECT_TRUE(r()->Resolve()) << r()->error();
auto* got = TypeOf(expr)->UnwrapPtrIfNeeded();
auto* expected = params.create_sem_type(ty);
EXPECT_EQ(got, expected) << "got: " << FriendlyName(got) << "\n"
<< "expected: " << FriendlyName(expected) << "\n";
}
INSTANTIATE_TEST_SUITE_P(ResolverTypeValidationTest,
CanonicalTest,
@ -529,26 +538,26 @@ INSTANTIATE_TEST_SUITE_P(ResolverTypeValidationTest,
testing::ValuesIn(dimension_cases));
struct TypeParams {
create_type_func_ptr type_func;
create_ast_type_func_ptr type_func;
bool is_valid;
};
static constexpr TypeParams type_cases[] = {
TypeParams{ty_bool_, false},
TypeParams{ty_i32, true},
TypeParams{ty_u32, true},
TypeParams{ty_f32, true},
TypeParams{ast_bool, false},
TypeParams{ast_i32, true},
TypeParams{ast_u32, true},
TypeParams{ast_f32, true},
TypeParams{ty_alias<ty_bool_>, false},
TypeParams{ty_alias<ty_i32>, true},
TypeParams{ty_alias<ty_u32>, true},
TypeParams{ty_alias<ty_f32>, true},
TypeParams{ast_alias<ast_bool>, false},
TypeParams{ast_alias<ast_i32>, true},
TypeParams{ast_alias<ast_u32>, true},
TypeParams{ast_alias<ast_f32>, true},
TypeParams{ty_vec3<ty_f32>, false},
TypeParams{ty_mat3x3<ty_f32>, false},
TypeParams{ast_vec3<ast_f32>, false},
TypeParams{ast_mat3x3<ast_f32>, false},
TypeParams{ty_alias<ty_vec3<ty_f32>>, false},
TypeParams{ty_alias<ty_mat3x3<ty_f32>>, false}};
TypeParams{ast_alias<ast_vec3<ast_f32>>, false},
TypeParams{ast_alias<ast_mat3x3<ast_f32>>, false}};
using MultisampledTextureTypeTest = ResolverTestWithParam<TypeParams>;
TEST_P(MultisampledTextureTypeTest, All) {

View File

@ -2041,8 +2041,9 @@ TEST_P(MatrixConstructorTest, Expr_Constructor_ElementTypeAlias_Success) {
TEST_F(ResolverValidationTest, Expr_MatrixConstructor_ArgumentTypeAlias_Error) {
auto alias = ty.alias("VectorUnsigned2", ty.vec2<u32>());
AST().AddConstructedType(alias);
auto* tc = mat2x2<f32>(create<ast::TypeConstructorExpression>(
Source{{12, 34}}, alias, ExprList()),
auto* tc = mat2x2<f32>(
create<ast::TypeConstructorExpression>(
Source{{12, 34}}, ty.MaybeCreateTypename(alias), ExprList()),
vec2<f32>());
WrapInFunction(tc);
@ -2062,7 +2063,7 @@ TEST_P(MatrixConstructorTest, Expr_Constructor_ArgumentTypeAlias_Success) {
ast::ExpressionList args;
for (uint32_t i = 1; i <= param.columns; i++) {
args.push_back(create<ast::TypeConstructorExpression>(
Source{{12, i}}, vec_alias, ExprList()));
Source{{12, i}}, ty.MaybeCreateTypename(vec_alias), ExprList()));
}
auto* tc = create<ast::TypeConstructorExpression>(Source{}, matrix_type,

View File

@ -28,12 +28,11 @@ namespace sem {
namespace {
ParameterList GetParameters(ast::Function* ast) {
ParameterList GetParameters(const std::vector<const Variable*>& params) {
ParameterList parameters;
parameters.reserve(ast->params().size());
for (auto* param : ast->params()) {
parameters.emplace_back(
Parameter{param->declared_type(), Parameter::Usage::kNone});
parameters.reserve(params.size());
for (auto* param : params) {
parameters.emplace_back(Parameter{param->Type(), Parameter::Usage::kNone});
}
return parameters;
}
@ -47,7 +46,7 @@ Function::Function(ast::Function* declaration,
std::vector<const Variable*> local_referenced_module_vars,
std::vector<const ast::ReturnStatement*> return_statements,
std::vector<Symbol> ancestor_entry_points)
: Base(return_type, GetParameters(declaration)),
: Base(return_type, GetParameters(parameters)),
declaration_(declaration),
parameters_(std::move(parameters)),
referenced_module_vars_(std::move(referenced_module_vars)),

View File

@ -30,10 +30,6 @@ Variable::Variable(const ast::Variable* declaration,
Variable::~Variable() = default;
sem::Type* Variable::DeclaredType() const {
return declaration_->declared_type();
}
VariableUser::VariableUser(ast::IdentifierExpression* declaration,
const sem::Type* type,
Statement* statement,

View File

@ -54,9 +54,6 @@ class Variable : public Castable<Variable, Node> {
/// @returns the canonical type for the variable
sem::Type* Type() const { return const_cast<sem::Type*>(type_); }
/// @returns the AST node's type. May be nullptr.
sem::Type* DeclaredType() const;
/// @returns the storage class for the variable
ast::StorageClass StorageClass() const { return storage_class_; }

View File

@ -65,13 +65,13 @@ Output BindingRemapper::Run(const Program* in, const DataMap& datamap) {
if (ac_it != remappings->access_controls.end()) {
ast::AccessControl::Access ac = ac_it->second;
auto* ty = in->Sem().Get(var)->Type();
sem::Type* inner_ty = nullptr;
ast::Type* inner_ty = nullptr;
if (auto* old_ac = ty->As<sem::AccessControl>()) {
inner_ty = ctx.Clone(old_ac->type());
inner_ty = CreateASTTypeFor(&ctx, old_ac->type());
} else {
inner_ty = ctx.Clone(ty);
inner_ty = CreateASTTypeFor(&ctx, ty);
}
auto* new_ty = ctx.dst->create<sem::AccessControl>(ac, inner_ty);
auto* new_ty = ctx.dst->create<ast::AccessControl>(ac, inner_ty);
auto* new_var = ctx.dst->create<ast::Variable>(
ctx.Clone(var->source()), ctx.Clone(var->symbol()),
var->declared_storage_class(), new_ty, var->is_const(),

View File

@ -81,6 +81,8 @@ Output CalculateArrayLength::Run(const Program* in, const DataMap&) {
auto get_buffer_size_intrinsic = [&](sem::StructType* buffer_type) {
return utils::GetOrCreate(buffer_size_intrinsics, buffer_type, [&] {
auto name = ctx.dst->Sym();
auto* buffer_typename =
ctx.dst->ty.type_name(ctx.Clone(buffer_type->impl()->name()));
auto* func = ctx.dst->create<ast::Function>(
name,
ast::VariableList{
@ -88,7 +90,7 @@ Output CalculateArrayLength::Run(const Program* in, const DataMap&) {
// in order for HLSL to emit this as a ByteAddressBuffer.
ctx.dst->create<ast::Variable>(
ctx.dst->Sym("buffer"), ast::StorageClass::kStorage,
ctx.Clone(buffer_type), true, nullptr, ast::DecorationList{}),
buffer_typename, true, nullptr, ast::DecorationList{}),
ctx.dst->Param("result",
ctx.dst->ty.pointer(ctx.dst->ty.u32(),
ast::StorageClass::kFunction)),
@ -98,7 +100,8 @@ Output CalculateArrayLength::Run(const Program* in, const DataMap&) {
ctx.dst->ASTNodes().Create<BufferSizeIntrinsic>(ctx.dst->ID()),
},
ast::DecorationList{});
ctx.InsertAfter(ctx.src->AST().GlobalDeclarations(), buffer_type, func);
ctx.InsertAfter(ctx.src->AST().GlobalDeclarations(), buffer_type->impl(),
func);
return name;
});
};

View File

@ -21,6 +21,7 @@
#include "src/program_builder.h"
#include "src/sem/function.h"
#include "src/sem/statement.h"
#include "src/sem/struct.h"
#include "src/sem/variable.h"
namespace tint {
@ -65,11 +66,11 @@ Output CanonicalizeEntryPointIO::Run(const Program* in, const DataMap&) {
// Strip entry point IO decorations from struct declarations.
// TODO(jrprice): This code is duplicated with the SPIR-V transform.
for (auto ty : ctx.src->AST().ConstructedTypes()) {
if (auto* struct_ty = ty->As<sem::StructType>()) {
for (auto* ty : ctx.src->AST().ConstructedTypes()) {
if (auto* struct_ty = ty->As<ast::Struct>()) {
// Build new list of struct members without entry point IO decorations.
ast::StructMemberList new_struct_members;
for (auto* member : struct_ty->impl()->members()) {
for (auto* member : struct_ty->members()) {
ast::DecorationList new_decorations = RemoveDecorations(
&ctx, member->decorations(), [](const ast::Decoration* deco) {
return deco
@ -81,49 +82,53 @@ Output CanonicalizeEntryPointIO::Run(const Program* in, const DataMap&) {
}
// Redeclare the struct.
auto new_struct_name = ctx.Clone(struct_ty->impl()->name());
auto new_struct_name = ctx.Clone(struct_ty->name());
auto* new_struct =
ctx.dst->create<sem::StructType>(ctx.dst->create<ast::Struct>(
new_struct_name, new_struct_members,
ctx.Clone(struct_ty->impl()->decorations())));
ctx.dst->create<ast::Struct>(new_struct_name, new_struct_members,
ctx.Clone(struct_ty->decorations()));
ctx.Replace(struct_ty, new_struct);
}
}
for (auto* func : ctx.src->AST().Functions()) {
if (!func->IsEntryPoint()) {
for (auto* func_ast : ctx.src->AST().Functions()) {
if (!func_ast->IsEntryPoint()) {
continue;
}
auto* func = ctx.src->Sem().Get(func_ast);
ast::VariableList new_parameters;
if (!func->params().empty()) {
if (!func->Parameters().empty()) {
// Collect all parameters and build a list of new struct members.
auto new_struct_param_symbol = ctx.dst->Sym();
ast::StructMemberList new_struct_members;
for (auto* param : func->params()) {
auto param_name = ctx.Clone(param->symbol());
auto* param_ty = ctx.src->Sem().Get(param)->Type();
auto* param_declared_ty = ctx.src->Sem().Get(param)->DeclaredType();
for (auto* param : func->Parameters()) {
auto param_name = ctx.Clone(param->Declaration()->symbol());
auto* param_ty = param->Type();
auto* param_declared_ty = param->Declaration()->type();
std::function<ast::Expression*()> func_const_initializer;
if (auto* struct_ty = param_ty->As<sem::StructType>()) {
auto* str = ctx.src->Sem().Get(struct_ty);
// Pull out all struct members and build initializer list.
std::vector<Symbol> member_names;
for (auto* member : struct_ty->impl()->members()) {
if (member->type()->UnwrapAll()->Is<sem::StructType>()) {
for (auto* member : str->Members()) {
if (member->Type()->UnwrapAll()->Is<sem::StructType>()) {
TINT_ICE(ctx.dst->Diagnostics()) << "nested pipeline IO struct";
}
ast::DecorationList new_decorations = RemoveDecorations(
&ctx, member->decorations(), [](const ast::Decoration* deco) {
&ctx, member->Declaration()->decorations(),
[](const ast::Decoration* deco) {
return !deco->IsAnyOf<ast::BuiltinDecoration,
ast::LocationDecoration>();
});
auto member_name = ctx.Clone(member->symbol());
new_struct_members.push_back(ctx.dst->Member(
member_name, ctx.Clone(member->type()), new_decorations));
auto member_name = ctx.Clone(member->Declaration()->symbol());
auto* member_type = ctx.Clone(member->Declaration()->type());
new_struct_members.push_back(
ctx.dst->Member(member_name, member_type, new_decorations));
member_names.emplace_back(member_name);
}
@ -139,7 +144,8 @@ Output CanonicalizeEntryPointIO::Run(const Program* in, const DataMap&) {
};
} else {
ast::DecorationList new_decorations = RemoveDecorations(
&ctx, param->decorations(), [](const ast::Decoration* deco) {
&ctx, param->Declaration()->decorations(),
[](const ast::Decoration* deco) {
return !deco->IsAnyOf<ast::BuiltinDecoration,
ast::LocationDecoration>();
});
@ -151,7 +157,7 @@ Output CanonicalizeEntryPointIO::Run(const Program* in, const DataMap&) {
};
}
if (func->body()->empty()) {
if (func_ast->body()->empty()) {
// Don't generate a function-scope const if the function is empty.
continue;
}
@ -160,11 +166,12 @@ Output CanonicalizeEntryPointIO::Run(const Program* in, const DataMap&) {
// Initialize it with the value extracted from the new struct parameter.
auto* func_const = ctx.dst->Const(
param_name, ctx.Clone(param_declared_ty), func_const_initializer());
ctx.InsertBefore(func->body()->statements(), *func->body()->begin(),
ctx.InsertBefore(func_ast->body()->statements(),
*func_ast->body()->begin(),
ctx.dst->WrapInStatement(func_const));
// Replace all uses of the function parameter with the function const.
for (auto* user : ctx.src->Sem().Get(param)->Users()) {
for (auto* user : param->Users()) {
ctx.Replace<ast::Expression>(user->Declaration(),
ctx.dst->Expr(param_name));
}
@ -176,44 +183,49 @@ Output CanonicalizeEntryPointIO::Run(const Program* in, const DataMap&) {
// Create the new struct type.
auto in_struct_name = ctx.dst->Sym();
auto* in_struct =
ctx.dst->create<sem::StructType>(ctx.dst->create<ast::Struct>(
in_struct_name, new_struct_members, ast::DecorationList{}));
ctx.InsertBefore(ctx.src->AST().GlobalDeclarations(), func, in_struct);
auto* in_struct = ctx.dst->create<ast::Struct>(
in_struct_name, new_struct_members, ast::DecorationList{});
ctx.InsertBefore(ctx.src->AST().GlobalDeclarations(), func_ast,
in_struct);
// Create a new function parameter using this struct type.
auto* struct_param = ctx.dst->Param(new_struct_param_symbol, in_struct);
auto* struct_param = ctx.dst->Param(
new_struct_param_symbol, ctx.dst->ty.type_name(in_struct_name));
new_parameters.push_back(struct_param);
}
// Handle return type.
auto* ret_type = func->return_type()->UnwrapAliasIfNeeded();
sem::Type* new_ret_type;
auto* ret_type = func->ReturnType()->UnwrapAliasIfNeeded();
std::function<ast::Type*()> new_ret_type;
if (ret_type->Is<sem::Void>()) {
new_ret_type = ctx.dst->ty.void_();
new_ret_type = [&ctx] { return ctx.dst->ty.void_(); };
} else {
ast::StructMemberList new_struct_members;
if (auto* struct_ty = ret_type->As<sem::StructType>()) {
auto* str = ctx.src->Sem().Get(struct_ty);
// Rebuild struct with only the entry point IO attributes.
for (auto* member : struct_ty->impl()->members()) {
if (member->type()->UnwrapAll()->Is<sem::StructType>()) {
for (auto* member : str->Members()) {
if (member->Type()->UnwrapAll()->Is<sem::StructType>()) {
TINT_ICE(ctx.dst->Diagnostics()) << "nested pipeline IO struct";
}
ast::DecorationList new_decorations = RemoveDecorations(
&ctx, member->decorations(), [](const ast::Decoration* deco) {
&ctx, member->Declaration()->decorations(),
[](const ast::Decoration* deco) {
return !deco->IsAnyOf<ast::BuiltinDecoration,
ast::LocationDecoration>();
});
auto symbol = ctx.Clone(member->Declaration()->symbol());
auto* member_ty = ctx.Clone(member->Declaration()->type());
new_struct_members.push_back(
ctx.dst->Member(ctx.Clone(member->symbol()),
ctx.Clone(member->type()), new_decorations));
ctx.dst->Member(symbol, member_ty, new_decorations));
}
} else {
auto* member_ty = ctx.Clone(func->Declaration()->return_type());
auto decos = ctx.Clone(func_ast->return_type_decorations());
new_struct_members.push_back(
ctx.dst->Member("value", ctx.Clone(ret_type),
ctx.Clone(func->return_type_decorations())));
ctx.dst->Member("value", member_ty, std::move(decos)));
}
// Sort struct members to satisfy HLSL interfacing matching rules.
@ -222,15 +234,16 @@ Output CanonicalizeEntryPointIO::Run(const Program* in, const DataMap&) {
// Create the new struct type.
auto out_struct_name = ctx.dst->Sym();
auto* out_struct =
ctx.dst->create<sem::StructType>(ctx.dst->create<ast::Struct>(
out_struct_name, new_struct_members, ast::DecorationList{}));
ctx.InsertBefore(ctx.src->AST().GlobalDeclarations(), func, out_struct);
new_ret_type = out_struct;
auto* out_struct = ctx.dst->create<ast::Struct>(
out_struct_name, new_struct_members, ast::DecorationList{});
ctx.InsertBefore(ctx.src->AST().GlobalDeclarations(), func_ast,
out_struct);
new_ret_type = [out_struct_name, &ctx] {
return ctx.dst->ty.type_name(out_struct_name);
};
// Replace all return statements.
auto* sem_func = ctx.src->Sem().Get(func);
for (auto* ret : sem_func->ReturnStatements()) {
for (auto* ret : func->ReturnStatements()) {
auto* ret_sem = ctx.src->Sem().Get(ret);
// Reconstruct the return value using the newly created struct.
std::function<ast::Expression*()> new_ret_value = [&ctx, ret] {
@ -243,8 +256,9 @@ Output CanonicalizeEntryPointIO::Run(const Program* in, const DataMap&) {
// Create a const to hold the return value expression to avoid
// re-evaluating it multiple times.
auto temp = ctx.dst->Sym();
auto* temp_var = ctx.dst->Decl(
ctx.dst->Const(temp, ctx.Clone(ret_type), new_ret_value()));
auto* ty = CreateASTTypeFor(&ctx, ret_type);
auto* temp_var =
ctx.dst->Decl(ctx.dst->Const(temp, ty, new_ret_value()));
ctx.InsertBefore(ret_sem->Block()->statements(), ret, temp_var);
new_ret_value = [&ctx, temp] { return ctx.dst->Expr(temp); };
}
@ -258,17 +272,17 @@ Output CanonicalizeEntryPointIO::Run(const Program* in, const DataMap&) {
}
auto* new_ret =
ctx.dst->Return(ctx.dst->Construct(new_ret_type, ret_values));
ctx.dst->Return(ctx.dst->Construct(new_ret_type(), ret_values));
ctx.Replace(ret, new_ret);
}
}
// Rewrite the function header with the new parameters.
auto* new_func = ctx.dst->create<ast::Function>(
func->source(), ctx.Clone(func->symbol()), new_parameters, new_ret_type,
ctx.Clone(func->body()), ctx.Clone(func->decorations()),
ast::DecorationList{});
ctx.Replace(func, new_func);
func_ast->source(), ctx.Clone(func_ast->symbol()), new_parameters,
new_ret_type(), ctx.Clone(func_ast->body()),
ctx.Clone(func_ast->decorations()), ast::DecorationList{});
ctx.Replace(func_ast, new_func);
}
ctx.Clone();

View File

@ -23,6 +23,7 @@
#include "src/ast/assignment_statement.h"
#include "src/ast/call_statement.h"
#include "src/ast/scalar_constructor_expression.h"
#include "src/ast/type_name.h"
#include "src/program_builder.h"
#include "src/sem/access_control_type.h"
#include "src/sem/array.h"
@ -318,7 +319,9 @@ DecomposeStorageAccess::Intrinsic* IntrinsicStoreFor(ProgramBuilder* builder,
/// Inserts `node` before `insert_after` in the global declarations of
/// `ctx.dst`. If `insert_after` is nullptr, then `node` is inserted at the top
/// of the module.
void InsertGlobal(CloneContext& ctx, Cloneable* insert_after, Cloneable* node) {
void InsertGlobal(CloneContext& ctx,
const Cloneable* insert_after,
Cloneable* node) {
auto& globals = ctx.src->AST().GlobalDeclarations();
if (insert_after) {
ctx.InsertAfter(globals, insert_after, node);
@ -328,7 +331,7 @@ void InsertGlobal(CloneContext& ctx, Cloneable* insert_after, Cloneable* node) {
}
/// @returns the unwrapped, user-declared constructed type of ty.
sem::Type* ConstructedTypeOf(sem::Type* ty) {
ast::NamedType* ConstructedTypeOf(sem::Type* ty) {
while (true) {
if (auto* ptr = ty->As<sem::Pointer>()) {
ty = ptr->type();
@ -338,11 +341,8 @@ sem::Type* ConstructedTypeOf(sem::Type* ty) {
ty = access->type();
continue;
}
if (auto* alias = ty->As<sem::Alias>()) {
return alias;
}
if (auto* str = ty->As<sem::StructType>()) {
return str;
return str->impl();
}
// Not a constructed type
return nullptr;
@ -368,8 +368,10 @@ struct Store {
StorageBufferAccess target; // The target for the write
};
} // namespace
/// State holds the current transform state
struct State {
struct DecomposeStorageAccess::State {
/// Map of AST expression to storage buffer access
/// This map has entries added when encountered, and removed when outer
/// expressions chain the access.
@ -385,9 +387,12 @@ struct State {
/// List of storage buffer writes
std::vector<Store> stores;
/// AddAccesss() adds the `expr -> access` map item to #accesses, and `expr`
/// AddAccess() adds the `expr -> access` map item to #accesses, and `expr`
/// to #expression_order.
void AddAccesss(ast::Expression* expr, StorageBufferAccess&& access) {
/// @param expr the expression that performs the access
/// @param access the access
void AddAccess(ast::Expression* expr, StorageBufferAccess&& access) {
TINT_ASSERT(access.type);
accesses.emplace(expr, std::move(access));
expression_order.emplace_back(expr);
}
@ -395,6 +400,8 @@ struct State {
/// TakeAccess() removes the `node` item from #accesses (if it exists),
/// returning the StorageBufferAccess. If #accesses does not hold an item for
/// `node`, an invalid StorageBufferAccess is returned.
/// @param node the expression that performed an access
/// @return the StorageBufferAccess for the given expression
StorageBufferAccess TakeAccess(ast::Expression* node) {
auto lhs_it = accesses.find(node);
if (lhs_it == accesses.end()) {
@ -408,24 +415,31 @@ struct State {
/// LoadFunc() returns a symbol to an intrinsic function that loads an element
/// of type `el_ty` from a storage buffer of type `buf_ty`. The function has
/// the signature: `fn load(buf : buf_ty, offset : u32) -> el_ty`
/// @param ctx the CloneContext
/// @param insert_after the user-declared type to insert the function after
/// @param buf_ty the storage buffer type
/// @param el_ty the storage buffer element type
/// @return the name of the function that performs the load
Symbol LoadFunc(CloneContext& ctx,
Cloneable* insert_after,
ast::NamedType* insert_after,
sem::Type* buf_ty,
sem::Type* el_ty) {
return utils::GetOrCreate(load_funcs, TypePair{buf_ty, el_ty}, [&] {
auto* buf_ast_ty = CreateASTTypeFor(&ctx, buf_ty);
ast::VariableList params = {
// Note: The buffer parameter requires the kStorage StorageClass in
// order for HLSL to emit this as a ByteAddressBuffer.
ctx.dst->create<ast::Variable>(
ctx.dst->Sym("buffer"), ast::StorageClass::kStorage,
ctx.Clone(buf_ty), true, nullptr, ast::DecorationList{}),
ctx.dst->Sym("buffer"), ast::StorageClass::kStorage, buf_ast_ty,
true, nullptr, ast::DecorationList{}),
ctx.dst->Param("offset", ctx.dst->ty.u32()),
};
ast::Function* func = nullptr;
if (auto* intrinsic = IntrinsicLoadFor(ctx.dst, el_ty)) {
auto* el_ast_ty = CreateASTTypeFor(&ctx, el_ty);
func = ctx.dst->create<ast::Function>(
ctx.dst->Sym(), params, ctx.Clone(el_ty), nullptr,
ctx.dst->Sym(), params, el_ast_ty, nullptr,
ast::DecorationList{intrinsic}, ast::DecorationList{});
} else {
ast::ExpressionList values;
@ -444,7 +458,7 @@ struct State {
for (auto* member : str->Members()) {
auto* offset = ctx.dst->Add("offset", member->Offset());
Symbol load = LoadFunc(ctx, insert_after, buf_ty,
member->Declaration()->type()->UnwrapAll());
member->Type()->UnwrapAll());
values.emplace_back(ctx.dst->Call(load, "buffer", offset));
}
} else if (auto* arr_ty = el_ty->As<sem::ArrayType>()) {
@ -457,11 +471,12 @@ struct State {
values.emplace_back(ctx.dst->Call(load, "buffer", offset));
}
}
auto* el_ast_ty = CreateASTTypeFor(&ctx, el_ty);
func = ctx.dst->create<ast::Function>(
ctx.dst->Sym(), params, ctx.Clone(el_ty),
ctx.dst->Sym(), params, el_ast_ty,
ctx.dst->Block(
ctx.dst->Return(ctx.dst->create<ast::TypeConstructorExpression>(
ctx.Clone(el_ty), values))),
CreateASTTypeFor(&ctx, el_ty), values))),
ast::DecorationList{}, ast::DecorationList{});
}
InsertGlobal(ctx, insert_after, func);
@ -472,19 +487,26 @@ struct State {
/// StoreFunc() returns a symbol to an intrinsic function that stores an
/// element of type `el_ty` to a storage buffer of type `buf_ty`. The function
/// has the signature: `fn store(buf : buf_ty, offset : u32, value : el_ty)`
/// @param ctx the CloneContext
/// @param insert_after the user-declared type to insert the function after
/// @param buf_ty the storage buffer type
/// @param el_ty the storage buffer element type
/// @return the name of the function that performs the store
Symbol StoreFunc(CloneContext& ctx,
Cloneable* insert_after,
ast::NamedType* insert_after,
sem::Type* buf_ty,
sem::Type* el_ty) {
return utils::GetOrCreate(store_funcs, TypePair{buf_ty, el_ty}, [&] {
auto* buf_ast_ty = CreateASTTypeFor(&ctx, buf_ty);
auto* el_ast_ty = CreateASTTypeFor(&ctx, el_ty);
ast::VariableList params{
// Note: The buffer parameter requires the kStorage StorageClass in
// order for HLSL to emit this as a ByteAddressBuffer.
ctx.dst->create<ast::Variable>(
ctx.dst->Sym("buffer"), ast::StorageClass::kStorage,
ctx.Clone(buf_ty), true, nullptr, ast::DecorationList{}),
ctx.dst->Sym("buffer"), ast::StorageClass::kStorage, buf_ast_ty,
true, nullptr, ast::DecorationList{}),
ctx.dst->Param("offset", ctx.dst->ty.u32()),
ctx.dst->Param("value", ctx.Clone(el_ty)),
ctx.dst->Param("value", el_ast_ty),
};
ast::Function* func = nullptr;
if (auto* intrinsic = IntrinsicStoreFor(ctx.dst, el_ty)) {
@ -512,9 +534,8 @@ struct State {
auto* offset = ctx.dst->Add("offset", member->Offset());
auto* access = ctx.dst->MemberAccessor(
"value", ctx.Clone(member->Declaration()->symbol()));
Symbol store =
StoreFunc(ctx, insert_after, buf_ty,
member->Declaration()->type()->UnwrapAll());
Symbol store = StoreFunc(ctx, insert_after, buf_ty,
member->Type()->UnwrapAll());
auto* call = ctx.dst->Call(store, "buffer", offset, access);
body.emplace_back(ctx.dst->create<ast::CallStatement>(call));
}
@ -541,8 +562,6 @@ struct State {
}
};
} // namespace
DecomposeStorageAccess::Intrinsic::Intrinsic(ProgramID program_id, Type ty)
: Base(program_id), type(ty) {}
DecomposeStorageAccess::Intrinsic::~Intrinsic() = default;
@ -630,7 +649,7 @@ Output DecomposeStorageAccess::Run(const Program* in, const DataMap&) {
if (auto* var = sem.Get<sem::VariableUser>(ident)) {
if (var->Variable()->StorageClass() == ast::StorageClass::kStorage) {
// Variable to a storage buffer
state.AddAccesss(ident, {
state.AddAccess(ident, {
var,
ToOffset(0u),
var->Type()->UnwrapAll(),
@ -649,7 +668,7 @@ Output DecomposeStorageAccess::Run(const Program* in, const DataMap&) {
auto* vec_ty = access.type->As<sem::Vector>();
auto offset =
Mul(ScalarSize(vec_ty->type()), swizzle->Indices()[0]);
state.AddAccesss(
state.AddAccess(
accessor, {
access.var,
Add(std::move(access.offset), std::move(offset)),
@ -663,11 +682,11 @@ Output DecomposeStorageAccess::Run(const Program* in, const DataMap&) {
auto* member =
sem.Get(str_ty)->FindMember(accessor->member()->symbol());
auto offset = member->Offset();
state.AddAccesss(accessor,
state.AddAccess(accessor,
{
access.var,
Add(std::move(access.offset), std::move(offset)),
member->Declaration()->type()->UnwrapAll(),
member->Type()->UnwrapAll(),
});
}
}
@ -680,7 +699,7 @@ Output DecomposeStorageAccess::Run(const Program* in, const DataMap&) {
if (auto* arr_ty = access.type->As<sem::ArrayType>()) {
auto stride = sem.Get(arr_ty)->Stride();
auto offset = Mul(stride, accessor->idx_expr());
state.AddAccesss(accessor,
state.AddAccess(accessor,
{
access.var,
Add(std::move(access.offset), std::move(offset)),
@ -690,7 +709,7 @@ Output DecomposeStorageAccess::Run(const Program* in, const DataMap&) {
}
if (auto* vec_ty = access.type->As<sem::Vector>()) {
auto offset = Mul(ScalarSize(vec_ty->type()), accessor->idx_expr());
state.AddAccesss(accessor,
state.AddAccess(accessor,
{
access.var,
Add(std::move(access.offset), std::move(offset)),
@ -702,7 +721,7 @@ Output DecomposeStorageAccess::Run(const Program* in, const DataMap&) {
auto offset = Mul(MatrixColumnStride(mat_ty), accessor->idx_expr());
auto* vec_ty = ctx.dst->create<sem::Vector>(
ctx.Clone(mat_ty->type()->UnwrapAll()), mat_ty->rows());
state.AddAccesss(accessor,
state.AddAccess(accessor,
{
access.var,
Add(std::move(access.offset), std::move(offset)),

View File

@ -95,6 +95,8 @@ class DecomposeStorageAccess : public Transform {
/// @param data optional extra transform-specific data
/// @returns the transformation result
Output Run(const Program* program, const DataMap& data = {}) override;
struct State;
};
} // namespace transform

View File

@ -100,7 +100,7 @@ void Hlsl::PromoteInitializersToConstVar(CloneContext& ctx) const {
// Create a new symbol for the constant
auto dst_symbol = ctx.dst->Sym();
// Clone the type
auto* dst_ty = ctx.Clone(src_ty);
auto* dst_ty = ctx.Clone(src_init->type());
// Clone the initializer
auto* dst_init = ctx.Clone(src_init);
// Construct the constant that holds the hoisted initializer

View File

@ -69,7 +69,7 @@ Output SingleEntryPoint::Run(const Program* in, const DataMap& data) {
// Clone any module-scope variables, types, and functions that are statically
// referenced by the target entry point.
for (auto* decl : in->AST().GlobalDeclarations()) {
if (auto* ty = decl->As<sem::Type>()) {
if (auto* ty = decl->As<ast::NamedType>()) {
// TODO(jrprice): Strip unused types.
out.AST().AddConstructedType(ctx.Clone(ty));
} else if (auto* var = decl->As<ast::Variable>()) {

View File

@ -23,6 +23,7 @@
#include "src/program_builder.h"
#include "src/sem/function.h"
#include "src/sem/statement.h"
#include "src/sem/struct.h"
#include "src/sem/variable.h"
namespace tint {
@ -110,11 +111,11 @@ void Spirv::HandleEntryPointIOTypes(CloneContext& ctx) const {
// ```
// Strip entry point IO decorations from struct declarations.
for (auto ty : ctx.src->AST().ConstructedTypes()) {
if (auto* struct_ty = ty->As<sem::StructType>()) {
for (auto* ty : ctx.src->AST().ConstructedTypes()) {
if (auto* struct_ty = ty->As<ast::Struct>()) {
// Build new list of struct members without entry point IO decorations.
ast::StructMemberList new_struct_members;
for (auto* member : struct_ty->impl()->members()) {
for (auto* member : struct_ty->members()) {
ast::DecorationList new_decorations = RemoveDecorations(
&ctx, member->decorations(), [](const ast::Decoration* deco) {
return deco
@ -126,52 +127,52 @@ void Spirv::HandleEntryPointIOTypes(CloneContext& ctx) const {
}
// Redeclare the struct.
auto new_struct_name = ctx.Clone(struct_ty->impl()->name());
auto new_struct_name = ctx.Clone(struct_ty->name());
auto* new_struct =
ctx.dst->create<sem::StructType>(ctx.dst->create<ast::Struct>(
new_struct_name, new_struct_members,
ctx.Clone(struct_ty->impl()->decorations())));
ctx.dst->create<ast::Struct>(new_struct_name, new_struct_members,
ctx.Clone(struct_ty->decorations()));
ctx.Replace(struct_ty, new_struct);
}
}
for (auto* func : ctx.src->AST().Functions()) {
if (!func->IsEntryPoint()) {
for (auto* func_ast : ctx.src->AST().Functions()) {
if (!func_ast->IsEntryPoint()) {
continue;
}
auto* func = ctx.src->Sem().Get(func_ast);
for (auto* param : func->params()) {
for (auto* param : func->Parameters()) {
Symbol new_var = HoistToInputVariables(
ctx, func, ctx.src->Sem().Get(param)->Type(),
ctx.src->Sem().Get(param)->DeclaredType(), param->decorations());
ctx, func_ast, param->Type(), param->Declaration()->type(),
param->Declaration()->decorations());
// Replace all uses of the function parameter with the new variable.
for (auto* user : ctx.src->Sem().Get(param)->Users()) {
for (auto* user : param->Users()) {
ctx.Replace<ast::Expression>(user->Declaration(),
ctx.dst->Expr(new_var));
}
}
if (!func->return_type()->Is<sem::Void>()) {
if (!func->ReturnType()->Is<sem::Void>()) {
ast::StatementList stores;
auto store_value_symbol = ctx.dst->Sym();
HoistToOutputVariables(
ctx, func, func->return_type(), func->return_type(),
func->return_type_decorations(), {}, store_value_symbol, stores);
ctx, func_ast, func->ReturnType(), func_ast->return_type(),
func_ast->return_type_decorations(), {}, store_value_symbol, stores);
// Create a function that writes a return value to all output variables.
auto* store_value =
ctx.dst->Param(store_value_symbol, ctx.Clone(func->return_type()));
auto* store_value = ctx.dst->Param(store_value_symbol,
ctx.Clone(func_ast->return_type()));
auto return_func_symbol = ctx.dst->Sym();
auto* return_func = ctx.dst->create<ast::Function>(
return_func_symbol, ast::VariableList{store_value},
ctx.dst->ty.void_(), ctx.dst->create<ast::BlockStatement>(stores),
ast::DecorationList{}, ast::DecorationList{});
ctx.InsertBefore(ctx.src->AST().GlobalDeclarations(), func, return_func);
ctx.InsertBefore(ctx.src->AST().GlobalDeclarations(), func_ast,
return_func);
// Replace all return statements with calls to the output function.
auto* sem_func = ctx.src->Sem().Get(func);
for (auto* ret : sem_func->ReturnStatements()) {
for (auto* ret : func->ReturnStatements()) {
auto* ret_sem = ctx.src->Sem().Get(ret);
auto* call = ctx.dst->Call(return_func_symbol, ctx.Clone(ret->value()));
ctx.InsertBefore(ret_sem->Block()->statements(), ret,
@ -181,11 +182,13 @@ void Spirv::HandleEntryPointIOTypes(CloneContext& ctx) const {
}
// Rewrite the function header to remove the parameters and return value.
auto name = ctx.Clone(func_ast->symbol());
auto* body = ctx.Clone(func_ast->body());
auto decos = ctx.Clone(func_ast->decorations());
auto* new_func = ctx.dst->create<ast::Function>(
func->source(), ctx.Clone(func->symbol()), ast::VariableList{},
ctx.dst->ty.void_(), ctx.Clone(func->body()),
ctx.Clone(func->decorations()), ast::DecorationList{});
ctx.Replace(func, new_func);
func_ast->source(), name, ast::VariableList{}, ctx.dst->ty.void_(),
body, decos, ast::DecorationList{});
ctx.Replace(func_ast, new_func);
}
}
@ -253,7 +256,7 @@ Symbol Spirv::HoistToInputVariables(
CloneContext& ctx,
const ast::Function* func,
sem::Type* ty,
sem::Type* declared_ty,
ast::Type* declared_ty,
const ast::DecorationList& decorations) const {
if (!ty->Is<sem::StructType>()) {
// Base case: create a global variable and return.
@ -273,9 +276,10 @@ Symbol Spirv::HoistToInputVariables(
// Recurse into struct members and build the initializer list.
std::vector<Symbol> init_value_names;
auto* struct_ty = ty->As<sem::StructType>();
for (auto* member : struct_ty->impl()->members()) {
for (auto* member : ctx.src->Sem().Get(struct_ty)->Members()) {
auto member_var = HoistToInputVariables(
ctx, func, member->type(), member->type(), member->decorations());
ctx, func, member->Type(), member->Declaration()->type(),
member->Declaration()->decorations());
init_value_names.emplace_back(member_var);
}
@ -302,7 +306,7 @@ Symbol Spirv::HoistToInputVariables(
void Spirv::HoistToOutputVariables(CloneContext& ctx,
const ast::Function* func,
sem::Type* ty,
sem::Type* declared_ty,
ast::Type* declared_ty,
const ast::DecorationList& decorations,
std::vector<Symbol> member_accesses,
Symbol store_value,
@ -333,11 +337,12 @@ void Spirv::HoistToOutputVariables(CloneContext& ctx,
// Recurse into struct members.
auto* struct_ty = ty->As<sem::StructType>();
for (auto* member : struct_ty->impl()->members()) {
member_accesses.push_back(ctx.Clone(member->symbol()));
HoistToOutputVariables(ctx, func, member->type(), member->type(),
member->decorations(), member_accesses, store_value,
stores);
for (auto* member : ctx.src->Sem().Get(struct_ty)->Members()) {
member_accesses.push_back(ctx.Clone(member->Declaration()->symbol()));
HoistToOutputVariables(ctx, func, member->Type(),
member->Declaration()->type(),
member->Declaration()->decorations(),
member_accesses, store_value, stores);
member_accesses.pop_back();
}
}

View File

@ -60,7 +60,7 @@ class Spirv : public Transform {
Symbol HoistToInputVariables(CloneContext& ctx,
const ast::Function* func,
sem::Type* ty,
sem::Type* declared_ty,
ast::Type* declared_ty,
const ast::DecorationList& decorations) const;
/// Recursively create module-scope output variables for `ty` and build a list
@ -74,7 +74,7 @@ class Spirv : public Transform {
void HoistToOutputVariables(CloneContext& ctx,
const ast::Function* func,
sem::Type* ty,
sem::Type* declared_ty,
ast::Type* declared_ty,
const ast::DecorationList& decorations,
std::vector<Symbol> member_accesses,
Symbol store_value,

View File

@ -49,7 +49,7 @@ ast::Function* Transform::CloneWithStatementsAtStart(
auto source = ctx->Clone(in->source());
auto symbol = ctx->Clone(in->symbol());
auto params = ctx->Clone(in->params());
auto return_type = ctx->Clone(in->return_type());
auto* return_type = ctx->Clone(in->return_type());
auto* body = ctx->dst->create<ast::BlockStatement>(
ctx->Clone(in->body()->source()), statements);
auto decos = ctx->Clone(in->decorations());

View File

@ -184,7 +184,7 @@ struct State {
// identifier strings instead of pointers, so we don't need to update
// any other place in the AST.
auto name = ctx.Clone(v->symbol());
auto* replacement = ctx.dst->Var(name, ctx.Clone(v->declared_type()),
auto* replacement = ctx.dst->Var(name, ctx.Clone(v->type()),
ast::StorageClass::kPrivate);
location_to_expr[location] = [this, name]() {
return ctx.dst->Expr(name);
@ -212,9 +212,9 @@ struct State {
{
ctx.dst->create<ast::StructBlockDecoration>(),
});
for (uint32_t i = 0; i < cfg.vertex_state.size(); ++i) {
auto access =
ctx.dst->ty.access(ast::AccessControl::kReadOnly, struct_type);
for (uint32_t i = 0; i < cfg.vertex_state.size(); ++i) {
// The decorated variable with struct type
ctx.dst->Global(
GetVertexBufferName(i), access, ast::StorageClass::kStorage, nullptr,
@ -369,7 +369,7 @@ struct State {
/// @param count how many elements the vector has
ast::Expression* AccessVec(uint32_t buffer,
uint32_t element_stride,
sem::Type* base_type,
ast::Type* base_type,
VertexFormat base_format,
uint32_t count) {
ast::ExpressionList expr_list;
@ -381,7 +381,7 @@ struct State {
}
return ctx.dst->create<ast::TypeConstructorExpression>(
ctx.dst->create<sem::Vector>(base_type, count), std::move(expr_list));
ctx.dst->create<ast::Vector>(base_type, count), std::move(expr_list));
}
/// Process a non-struct entry point parameter.
@ -394,7 +394,7 @@ struct State {
ast::GetDecoration<ast::LocationDecoration>(param->decorations())) {
// Create a function-scope variable to replace the parameter.
auto func_var_sym = ctx.Clone(param->symbol());
auto* func_var_type = ctx.Clone(param->declared_type());
auto* func_var_type = ctx.Clone(param->type());
auto* func_var = ctx.dst->Var(func_var_sym, func_var_type,
ast::StorageClass::kFunction);
ctx.InsertBefore(func->body()->statements(), *func->body()->begin(),
@ -428,18 +428,16 @@ struct State {
/// instance_index builtins.
/// @param func the entry point function
/// @param param the parameter to process
void ProcessStructParameter(ast::Function* func, ast::Variable* param) {
auto* struct_ty = param->declared_type()->As<sem::StructType>();
if (!struct_ty) {
TINT_ICE(ctx.dst->Diagnostics()) << "Invalid struct parameter";
}
/// @param struct_ty the structure type
void ProcessStructParameter(ast::Function* func,
ast::Variable* param,
ast::Struct* struct_ty) {
auto param_sym = ctx.Clone(param->symbol());
// Process the struct members.
bool has_locations = false;
ast::StructMemberList members_to_clone;
for (auto* member : struct_ty->impl()->members()) {
for (auto* member : struct_ty->members()) {
auto member_sym = ctx.Clone(member->symbol());
std::function<ast::Expression*()> member_expr = [this, param_sym,
member_sym]() {
@ -472,7 +470,7 @@ struct State {
}
// Create a function-scope variable to replace the parameter.
auto* func_var = ctx.dst->Var(param_sym, ctx.Clone(param->declared_type()),
auto* func_var = ctx.dst->Var(param_sym, ctx.Clone(param->type()),
ast::StorageClass::kFunction);
ctx.InsertBefore(func->body()->statements(), *func->body()->begin(),
ctx.dst->Decl(func_var));
@ -482,7 +480,7 @@ struct State {
ast::StructMemberList new_members;
for (auto* member : members_to_clone) {
auto member_sym = ctx.Clone(member->symbol());
auto member_type = ctx.Clone(member->type());
auto* member_type = ctx.Clone(member->type());
auto member_decos = ctx.Clone(member->decorations());
new_members.push_back(
ctx.dst->Member(member_sym, member_type, std::move(member_decos)));
@ -514,8 +512,8 @@ struct State {
// Process entry point parameters.
for (auto* param : func->params()) {
auto* sem = ctx.src->Sem().Get(param);
if (sem->Type()->Is<sem::StructType>()) {
ProcessStructParameter(func, param);
if (auto* str = sem->Type()->As<sem::StructType>()) {
ProcessStructParameter(func, param, str->impl());
} else {
ProcessNonStructParameter(func, param);
}
@ -553,7 +551,7 @@ struct State {
// Rewrite the function header with the new parameters.
auto func_sym = ctx.Clone(func->symbol());
auto ret_type = ctx.Clone(func->return_type());
auto* ret_type = ctx.Clone(func->return_type());
auto* body = ctx.Clone(func->body());
auto decos = ctx.Clone(func->decorations());
auto ret_decos = ctx.Clone(func->return_type_decorations());

View File

@ -23,9 +23,10 @@ namespace writer {
namespace {
ast::TypeConstructorExpression* AsVectorConstructor(ast::Expression* expr) {
ast::TypeConstructorExpression* AsVectorConstructor(ProgramBuilder* b,
ast::Expression* expr) {
if (auto* constructor = expr->As<ast::TypeConstructorExpression>()) {
if (constructor->type()->Is<sem::Vector>()) {
if (b->TypeOf(constructor)->Is<sem::Vector>()) {
return constructor;
}
}
@ -38,42 +39,52 @@ ast::TypeConstructorExpression* AppendVector(ProgramBuilder* b,
ast::Expression* vector,
ast::Expression* scalar) {
uint32_t packed_size;
sem::Type* packed_el_ty; // Currently must be f32.
sem::Type* packed_el_sem_ty;
auto* vector_sem = b->Sem().Get(vector);
auto* vector_ty = vector_sem->Type()->UnwrapPtrIfNeeded();
if (auto* vec = vector_ty->As<sem::Vector>()) {
packed_size = vec->size() + 1;
packed_el_ty = vec->type();
packed_el_sem_ty = vec->type();
} else {
packed_size = 2;
packed_el_ty = vector_ty;
packed_el_sem_ty = vector_ty;
}
ast::Type* packed_el_ty = nullptr;
if (packed_el_sem_ty->Is<sem::I32>()) {
packed_el_ty = b->create<ast::I32>();
} else if (packed_el_sem_ty->Is<sem::U32>()) {
packed_el_ty = b->create<ast::U32>();
} else if (packed_el_sem_ty->Is<sem::F32>()) {
packed_el_ty = b->create<ast::F32>();
}
auto* statement = vector_sem->Stmt();
auto* packed_ty = b->create<sem::Vector>(packed_el_ty, packed_size);
auto* packed_ty = b->create<ast::Vector>(packed_el_ty, packed_size);
auto* packed_sem_ty = b->create<sem::Vector>(packed_el_sem_ty, packed_size);
// If the coordinates are already passed in a vector constructor, extract
// the elements into the new vector instead of nesting a vector-in-vector.
ast::ExpressionList packed;
if (auto* vc = AsVectorConstructor(vector)) {
if (auto* vc = AsVectorConstructor(b, vector)) {
packed = vc->values();
} else {
packed.emplace_back(vector);
}
if (packed_el_ty != b->Sem().Get(scalar)->Type()->UnwrapPtrIfNeeded()) {
if (packed_el_sem_ty != b->TypeOf(scalar)->UnwrapPtrIfNeeded()) {
// Cast scalar to the vector element type
auto* scalar_cast = b->Construct(packed_el_ty, scalar);
b->Sem().Add(scalar_cast, b->create<sem::Expression>(
scalar_cast, packed_el_ty, statement));
scalar_cast, packed_el_sem_ty, statement));
packed.emplace_back(scalar_cast);
} else {
packed.emplace_back(scalar);
}
auto* constructor = b->Construct(packed_ty, std::move(packed));
b->Sem().Add(constructor,
b->create<sem::Expression>(constructor, packed_ty, statement));
b->Sem().Add(constructor, b->create<sem::Expression>(
constructor, packed_sem_ty, statement));
return constructor;
}

View File

@ -119,8 +119,11 @@ bool GeneratorImpl::Generate(std::ostream& out) {
register_global(global);
}
for (auto const ty : builder_.AST().ConstructedTypes()) {
if (!EmitConstructedType(out, ty)) {
for (auto* const ty : builder_.AST().ConstructedTypes()) {
if (ty->Is<ast::Alias>()) {
continue;
}
if (!EmitConstructedType(out, TypeOf(ty))) {
return false;
}
}
@ -2012,7 +2015,7 @@ bool GeneratorImpl::EmitEntryPointFunction(std::ostream& out,
bool has_outdata = outdata != ep_sym_to_out_data_.end();
if (has_outdata) {
// TODO(crbug.com/tint/697): Remove this.
if (!func->return_type()->Is<sem::Void>()) {
if (!func->return_type()->Is<ast::Void>()) {
TINT_ICE(diagnostics_) << "Mixing module-scope variables and return "
"types for shader outputs";
}

View File

@ -381,7 +381,7 @@ class GeneratorImpl : public TextGenerator {
/// @returns the resolved type of the ast::Type `type`
/// @param type the type
const sem::Type* TypeOf(ast::Type* type) const {
const sem::Type* TypeOf(const ast::Type* type) const {
return builder_.TypeOf(type);
}

View File

@ -28,6 +28,7 @@
#include "src/ast/sint_literal.h"
#include "src/ast/uint_literal.h"
#include "src/ast/variable_decl_statement.h"
#include "src/ast/void.h"
#include "src/sem/access_control_type.h"
#include "src/sem/alias_type.h"
#include "src/sem/array.h"
@ -86,8 +87,8 @@ bool GeneratorImpl::Generate() {
global_variables_.set(global->symbol(), sem);
}
for (auto const ty : program_->AST().ConstructedTypes()) {
if (!EmitConstructedType(ty)) {
for (auto* const ty : program_->AST().ConstructedTypes()) {
if (!EmitConstructedType(TypeOf(ty))) {
return false;
}
}
@ -1400,7 +1401,7 @@ bool GeneratorImpl::EmitEntryPointFunction(ast::Function* func) {
bool has_out_data = out_data != ep_sym_to_out_data_.end();
if (has_out_data) {
// TODO(crbug.com/tint/697): Remove this.
if (!func->return_type()->Is<sem::Void>()) {
if (!func->return_type()->Is<ast::Void>()) {
TINT_ICE(diagnostics_) << "Mixing module-scope variables and return "
"types for shader outputs";
}

View File

@ -275,7 +275,7 @@ class GeneratorImpl : public TextGenerator {
/// @returns the resolved type of the ast::Type `type`
/// @param type the type
const sem::Type* TypeOf(ast::Type* type) const {
const sem::Type* TypeOf(const ast::Type* type) const {
return program_->TypeOf(type);
}

View File

@ -842,9 +842,8 @@ TEST_F(BuilderTest_Type, SampledTexture_Generate_CubeArray) {
}
TEST_F(BuilderTest_Type, StorageTexture_Generate_1d) {
auto* s = create<sem::StorageTexture>(
ast::TextureDimension::k1d, ast::ImageFormat::kR32Float,
sem::StorageTexture::SubtypeFor(ast::ImageFormat::kR32Float, Types()));
auto s = ty.storage_texture(ast::TextureDimension::k1d,
ast::ImageFormat::kR32Float);
auto ac = ty.access(ast::AccessControl::kReadOnly, s);
Global("test_var", ac, ast::StorageClass::kInput);
@ -859,9 +858,8 @@ TEST_F(BuilderTest_Type, StorageTexture_Generate_1d) {
}
TEST_F(BuilderTest_Type, StorageTexture_Generate_2d) {
auto* s = create<sem::StorageTexture>(
ast::TextureDimension::k2d, ast::ImageFormat::kR32Float,
sem::StorageTexture::SubtypeFor(ast::ImageFormat::kR32Float, Types()));
auto s = ty.storage_texture(ast::TextureDimension::k2d,
ast::ImageFormat::kR32Float);
auto ac = ty.access(ast::AccessControl::kReadOnly, s);
Global("test_var", ac, ast::StorageClass::kInput);
@ -876,9 +874,8 @@ TEST_F(BuilderTest_Type, StorageTexture_Generate_2d) {
}
TEST_F(BuilderTest_Type, StorageTexture_Generate_2dArray) {
auto* s = create<sem::StorageTexture>(
ast::TextureDimension::k2dArray, ast::ImageFormat::kR32Float,
sem::StorageTexture::SubtypeFor(ast::ImageFormat::kR32Float, Types()));
auto s = ty.storage_texture(ast::TextureDimension::k2dArray,
ast::ImageFormat::kR32Float);
auto ac = ty.access(ast::AccessControl::kReadOnly, s);
Global("test_var", ac, ast::StorageClass::kInput);
@ -893,9 +890,8 @@ TEST_F(BuilderTest_Type, StorageTexture_Generate_2dArray) {
}
TEST_F(BuilderTest_Type, StorageTexture_Generate_3d) {
auto* s = create<sem::StorageTexture>(
ast::TextureDimension::k3d, ast::ImageFormat::kR32Float,
sem::StorageTexture::SubtypeFor(ast::ImageFormat::kR32Float, Types()));
auto s = ty.storage_texture(ast::TextureDimension::k3d,
ast::ImageFormat::kR32Float);
auto ac = ty.access(ast::AccessControl::kReadOnly, s);
Global("test_var", ac, ast::StorageClass::kInput);
@ -911,9 +907,8 @@ TEST_F(BuilderTest_Type, StorageTexture_Generate_3d) {
TEST_F(BuilderTest_Type,
StorageTexture_Generate_SampledTypeFloat_Format_r32float) {
auto* s = create<sem::StorageTexture>(
ast::TextureDimension::k2d, ast::ImageFormat::kR32Float,
sem::StorageTexture::SubtypeFor(ast::ImageFormat::kR32Float, Types()));
auto s = ty.storage_texture(ast::TextureDimension::k2d,
ast::ImageFormat::kR32Float);
auto ac = ty.access(ast::AccessControl::kReadOnly, s);
Global("test_var", ac, ast::StorageClass::kInput);
@ -929,9 +924,8 @@ TEST_F(BuilderTest_Type,
TEST_F(BuilderTest_Type,
StorageTexture_Generate_SampledTypeSint_Format_r32sint) {
auto* s = create<sem::StorageTexture>(
ast::TextureDimension::k2d, ast::ImageFormat::kR32Sint,
sem::StorageTexture::SubtypeFor(ast::ImageFormat::kR32Sint, Types()));
auto s = ty.storage_texture(ast::TextureDimension::k2d,
ast::ImageFormat::kR32Sint);
auto ac = ty.access(ast::AccessControl::kReadOnly, s);
Global("test_var", ac, ast::StorageClass::kInput);
@ -947,9 +941,8 @@ TEST_F(BuilderTest_Type,
TEST_F(BuilderTest_Type,
StorageTexture_Generate_SampledTypeUint_Format_r32uint) {
auto* s = create<sem::StorageTexture>(
ast::TextureDimension::k2d, ast::ImageFormat::kR32Uint,
sem::StorageTexture::SubtypeFor(ast::ImageFormat::kR32Uint, Types()));
auto s = ty.storage_texture(ast::TextureDimension::k2d,
ast::ImageFormat::kR32Uint);
auto ac = ty.access(ast::AccessControl::kReadOnly, s);
Global("test_var", ac, ast::StorageClass::kInput);

View File

@ -83,10 +83,6 @@ bool GeneratorImpl::Generate(const ast::Function* entry) {
if (!EmitConstructedType(ty)) {
return false;
}
} else if (auto* sem_ty = decl->As<sem::Type>()) {
if (!EmitConstructedType(sem_ty)) {
return false;
}
} else if (auto* func = decl->As<ast::Function>()) {
if (entry && func != entry) {
// Skip functions that are not reachable by the target entry point.
@ -363,8 +359,7 @@ bool GeneratorImpl::EmitFunction(ast::Function* func) {
out_ << ")";
if (!(Is<ast::Void>(func->return_type().ast) ||
Is<sem::Void>(func->return_type().sem)) ||
if (!func->return_type()->Is<ast::Void>() ||
!func->return_type_decorations().empty()) {
out_ << " -> ";
@ -776,9 +771,9 @@ bool GeneratorImpl::EmitVariable(ast::Variable* var) {
out_ << " " << program_->Symbols().NameFor(var->symbol());
if (var->type().ast || var->type().sem) {
if (auto* ty = var->type()) {
out_ << " : ";
if (!EmitType(var->type())) {
if (!EmitType(ty)) {
return false;
}
}

View File

@ -32,8 +32,8 @@ TEST_P(WgslUnaryOpTest, Emit) {
auto params = GetParam();
auto* type = (params.op == ast::UnaryOp::kNot)
? static_cast<sem::Type*>(ty.bool_())
: static_cast<sem::Type*>(ty.i32());
? static_cast<ast::Type*>(ty.bool_())
: static_cast<ast::Type*>(ty.i32());
Global("expr", type, ast::StorageClass::kPrivate);
auto* op = create<ast::UnaryOpExpression>(params.op, Expr("expr"));