From fe0910fa37760d4e7b53a13645daa0085bd9cab7 Mon Sep 17 00:00:00 2001 From: Ben Clayton Date: Mon, 17 May 2021 15:51:47 +0000 Subject: [PATCH] ProgramBuilder: New helpers,change WrapInStatement Add AddressOf() and Deref() Add overloads of Expr() that take a source Change WrapInStatement() to create a `let`. Unlike `var`, `let` can be used to hold pointers. Bug: tint:727 Change-Id: Ib2cd7ab7a7056862e064943dea04387f7e466212 Reviewed-on: https://dawn-review.googlesource.com/c/tint/+/51183 Commit-Queue: Ben Clayton Kokoro: Kokoro Reviewed-by: Antonio Maiorano --- src/program_builder.cc | 11 +- src/program_builder.h | 103 +++++++++++++++++- .../generator_impl_member_accessor_test.cc | 4 +- 3 files changed, 110 insertions(+), 8 deletions(-) diff --git a/src/program_builder.cc b/src/program_builder.cc index a59b03a0b9..6cf5eac812 100644 --- a/src/program_builder.cc +++ b/src/program_builder.cc @@ -20,6 +20,7 @@ #include "src/debug.h" #include "src/demangler.h" #include "src/sem/expression.h" +#include "src/sem/variable.h" namespace tint { @@ -90,6 +91,11 @@ sem::Type* ProgramBuilder::TypeOf(const ast::Expression* expr) const { return sem ? sem->Type() : nullptr; } +sem::Type* ProgramBuilder::TypeOf(const ast::Variable* var) const { + auto* sem = Sem().Get(var); + return sem ? sem->Type() : nullptr; +} + const sem::Type* ProgramBuilder::TypeOf(const ast::Type* type) const { return Sem().Get(type); } @@ -162,8 +168,11 @@ ast::Statement* ProgramBuilder::WrapInStatement(ast::Literal* lit) { } ast::Statement* ProgramBuilder::WrapInStatement(ast::Expression* expr) { + if (auto* ce = expr->As()) { + return create(ce); + } // Create a temporary variable of inferred type from expr. - return Decl(Var(symbols_.New(), nullptr, ast::StorageClass::kFunction, expr)); + return Decl(Const(symbols_.New(), nullptr, expr)); } ast::VariableDeclStatement* ProgramBuilder::WrapInStatement(ast::Variable* v) { diff --git a/src/program_builder.h b/src/program_builder.h index 42375a9293..74367bcbeb 100644 --- a/src/program_builder.h +++ b/src/program_builder.h @@ -55,6 +55,7 @@ #include "src/ast/type_name.h" #include "src/ast/u32.h" #include "src/ast/uint_literal.h" +#include "src/ast/unary_op_expression.h" #include "src/ast/variable_decl_statement.h" #include "src/ast/vector.h" #include "src/ast/void.h" @@ -920,10 +921,11 @@ class ProgramBuilder { /// @return nullptr ast::IdentifierExpression* Expr(std::nullptr_t) { return nullptr; } - /// @param name the identifier name - /// @return an ast::IdentifierExpression with the given name - ast::IdentifierExpression* Expr(const std::string& name) { - return create(Symbols().Register(name)); + /// @param source the source information + /// @param symbol the identifier symbol + /// @return an ast::IdentifierExpression with the given symbol + ast::IdentifierExpression* Expr(const Source& source, Symbol symbol) { + return create(source, symbol); } /// @param symbol the identifier symbol @@ -932,12 +934,33 @@ class ProgramBuilder { return create(symbol); } + /// @param source the source information + /// @param variable the AST variable + /// @return an ast::IdentifierExpression with the variable's symbol + ast::IdentifierExpression* Expr(const Source& source, + ast::Variable* variable) { + return create(source, variable->symbol()); + } + /// @param variable the AST variable /// @return an ast::IdentifierExpression with the variable's symbol ast::IdentifierExpression* Expr(ast::Variable* variable) { return create(variable->symbol()); } + /// @param source the source information + /// @param name the identifier name + /// @return an ast::IdentifierExpression with the given name + ast::IdentifierExpression* Expr(const Source& source, const char* name) { + return create(source, Symbols().Register(name)); + } + + /// @param name the identifier name + /// @return an ast::IdentifierExpression with the given name + ast::IdentifierExpression* Expr(const char* name) { + return create(Symbols().Register(name)); + } + /// @param source the source information /// @param name the identifier name /// @return an ast::IdentifierExpression with the given name @@ -948,28 +971,56 @@ class ProgramBuilder { /// @param name the identifier name /// @return an ast::IdentifierExpression with the given name - ast::IdentifierExpression* Expr(const char* name) { + ast::IdentifierExpression* Expr(const std::string& name) { return create(Symbols().Register(name)); } + /// @param source the source information + /// @param value the boolean value + /// @return a Scalar constructor for the given value + ast::ScalarConstructorExpression* Expr(const Source& source, bool value) { + return create(source, Literal(value)); + } + /// @param value the boolean value /// @return a Scalar constructor for the given value ast::ScalarConstructorExpression* Expr(bool value) { return create(Literal(value)); } + /// @param source the source information + /// @param value the float value + /// @return a Scalar constructor for the given value + ast::ScalarConstructorExpression* Expr(const Source& source, f32 value) { + return create(source, Literal(value)); + } + /// @param value the float value /// @return a Scalar constructor for the given value ast::ScalarConstructorExpression* Expr(f32 value) { return create(Literal(value)); } + /// @param source the source information + /// @param value the integer value + /// @return a Scalar constructor for the given value + ast::ScalarConstructorExpression* Expr(const Source& source, i32 value) { + return create(source, Literal(value)); + } + /// @param value the integer value /// @return a Scalar constructor for the given value ast::ScalarConstructorExpression* Expr(i32 value) { return create(Literal(value)); } + /// @param source the source information + /// @param value the unsigned int value + /// @return a Scalar constructor for the given value + ast::ScalarConstructorExpression* Expr(const Source& source, u32 value) { + return create(source, Literal(value)); + } + /// @param value the unsigned int value /// @return a Scalar constructor for the given value ast::ScalarConstructorExpression* Expr(u32 value) { @@ -1354,6 +1405,40 @@ class ProgramBuilder { return var; } + /// @param source the source information + /// @param expr the expression to take the address of + /// @return an ast::UnaryOpExpression that takes the address of `expr` + template + ast::UnaryOpExpression* AddressOf(const Source& source, EXPR&& expr) { + return create(source, ast::UnaryOp::kAddressOf, + Expr(std::forward(expr))); + } + + /// @param expr the expression to take the address of + /// @return an ast::UnaryOpExpression that takes the address of `expr` + template + ast::UnaryOpExpression* AddressOf(EXPR&& expr) { + return create(ast::UnaryOp::kAddressOf, + Expr(std::forward(expr))); + } + + /// @param source the source information + /// @param expr the expression to perform an indirection on + /// @return an ast::UnaryOpExpression that dereferences the pointer `expr` + template + ast::UnaryOpExpression* Deref(const Source& source, EXPR&& expr) { + return create(source, ast::UnaryOp::kIndirection, + Expr(std::forward(expr))); + } + + /// @param expr the expression to perform an indirection on + /// @return an ast::UnaryOpExpression that dereferences the pointer `expr` + template + ast::UnaryOpExpression* Deref(EXPR&& expr) { + return create(ast::UnaryOp::kIndirection, + Expr(std::forward(expr))); + } + /// @param func the function name /// @param args the function call arguments /// @returns a `ast::CallExpression` to the function `func`, with the @@ -1845,6 +1930,14 @@ class ProgramBuilder { /// expression has no resolved type. sem::Type* TypeOf(const ast::Expression* expr) const; + /// Helper for returning the resolved semantic type of the variable `var`. + /// @note As the Resolver is run when the Program is built, this will only be + /// useful for the Resolver itself and tests that use their own Resolver. + /// @param var the AST variable + /// @return the resolved semantic type for the variable, or nullptr if the + /// variable has no resolved type. + sem::Type* TypeOf(const ast::Variable* var) const; + /// Helper for returning the resolved semantic type of the AST type `type`. /// @note As the Resolver is run when the Program is built, this will only be /// useful for the Resolver itself and tests that use their own Resolver. diff --git a/src/writer/hlsl/generator_impl_member_accessor_test.cc b/src/writer/hlsl/generator_impl_member_accessor_test.cc index 8f0d658eb5..423a0465d4 100644 --- a/src/writer/hlsl/generator_impl_member_accessor_test.cc +++ b/src/writer/hlsl/generator_impl_member_accessor_test.cc @@ -128,7 +128,7 @@ TEST_F(HlslGeneratorImplTest_MemberAccessor, EmitExpression_MemberAccessor) { Global("str", s, ast::StorageClass::kPrivate); auto* expr = MemberAccessor("str", "mem"); - WrapInFunction(expr); + WrapInFunction(Var("expr", ty.f32(), ast::StorageClass::kNone, expr)); GeneratorImpl& gen = SanitizeAndBuild(); @@ -141,7 +141,7 @@ Data str; [numthreads(1, 1, 1)] void test_function() { - float tint_symbol = str.mem; + float expr = str.mem; return; }