From 9481156eb941ad209bc6a004082964a975c040e5 Mon Sep 17 00:00:00 2001 From: Ben Clayton Date: Fri, 30 Apr 2021 10:16:45 +0000 Subject: [PATCH] Fix Undefined Behaviour All caused by calling Castable::As<> on nullptr objects. Bug: tint:760 Change-Id: I0a408b3cd58086cfeab5a1af34d643f50f304948 Reviewed-on: https://dawn-review.googlesource.com/c/tint/+/49523 Reviewed-by: Corentin Wallez Commit-Queue: Ben Clayton --- src/castable.h | 15 +++++++++---- src/program_builder.cc | 4 ++-- src/reader/wgsl/parser_impl.cc | 2 +- src/sem/info.h | 16 ++++++++++---- src/transform/decompose_storage_access.cc | 3 +-- src/transform/first_index_offset.cc | 27 +++++++++++++---------- 6 files changed, 42 insertions(+), 25 deletions(-) diff --git a/src/castable.h b/src/castable.h index 884ba5f619..5bb8fdb58b 100644 --- a/src/castable.h +++ b/src/castable.h @@ -173,13 +173,20 @@ inline bool IsAnyOf(FROM* obj) { /// @see CastFlags template inline TO* As(FROM* obj) { - using castable = - typename std::conditional::value, const CastableBase, - CastableBase>::type; - auto* as_castable = static_cast(obj); + auto* as_castable = static_cast(obj); return Is(obj) ? static_cast(as_castable) : nullptr; } +/// @returns obj dynamically cast to the type `TO` or `nullptr` if +/// this object does not derive from `TO`. +/// @param obj the object to cast from +/// @see CastFlags +template +inline const TO* As(const FROM* obj) { + auto* as_castable = static_cast(obj); + return Is(obj) ? static_cast(as_castable) : nullptr; +} + /// CastableBase is the base class for all Castable objects. /// It is not encouraged to directly derive from CastableBase without using the /// Castable helper template. diff --git a/src/program_builder.cc b/src/program_builder.cc index 7689f98b50..c16f878dff 100644 --- a/src/program_builder.cc +++ b/src/program_builder.cc @@ -132,10 +132,10 @@ ast::ConstructorExpression* ProgramBuilder::ConstructValueFilledWith( typ::Type ProgramBuilder::TypesBuilder::MaybeCreateTypename( typ::Type type) const { - if (auto* alias = type.ast->As()) { + if (auto* alias = As(type.ast)) { return {builder->create(alias->symbol()), type.sem}; } - if (auto* str = type.ast->As()) { + if (auto* str = As(type.ast)) { return {builder->create(str->name()), type.sem}; } return type; diff --git a/src/reader/wgsl/parser_impl.cc b/src/reader/wgsl/parser_impl.cc index d303610668..99d75486a3 100644 --- a/src/reader/wgsl/parser_impl.cc +++ b/src/reader/wgsl/parser_impl.cc @@ -1396,7 +1396,7 @@ Maybe ParserImpl::function_header() { return_type = type.value; } - if (return_type.ast->Is()) { + if (Is(return_type.ast)) { // crbug.com/tint/677: void has been removed from the language deprecated(tok.source(), "omit '-> void' for functions that do not return a value"); diff --git a/src/sem/info.h b/src/sem/info.h index 412e3ef856..b442df2359 100644 --- a/src/sem/info.h +++ b/src/sem/info.h @@ -15,6 +15,7 @@ #ifndef SRC_SEM_INFO_H_ #define SRC_SEM_INFO_H_ +#include #include #include "src/debug.h" @@ -26,6 +27,9 @@ namespace sem { /// Info holds all the resolved semantic information for a Program. class Info { + /// Placeholder type used by Get() to provide a default value for EXPLICIT_SEM + using InferFromAST = std::nullptr_t; + public: /// Constructor Info(); @@ -44,14 +48,18 @@ class Info { /// Get looks up the semantic information for the AST or type node `node`. /// @param node the AST or type node /// @returns a pointer to the semantic node if found, otherwise nullptr - template > - const SEM* Get(const AST_OR_TYPE* node) const { + template ::value, + SemanticNodeTypeFor, + SEM>> + const RESULT* Get(const AST_OR_TYPE* node) const { auto it = map.find(node); if (it == map.end()) { return nullptr; } - return it->second->template As(); + return As(it->second); } /// Add registers the semantic node `sem_node` for the AST or type node diff --git a/src/transform/decompose_storage_access.cc b/src/transform/decompose_storage_access.cc index 44a768b0d3..50b51a571f 100644 --- a/src/transform/decompose_storage_access.cc +++ b/src/transform/decompose_storage_access.cc @@ -627,8 +627,7 @@ Output DecomposeStorageAccess::Run(const Program* in, const DataMap&) { for (auto* node : ctx.src->ASTNodes().Objects()) { if (auto* ident = node->As()) { // X - auto* expr = sem.Get(ident); - if (auto* var = expr->As()) { + if (auto* var = sem.Get(ident)) { if (var->Variable()->StorageClass() == ast::StorageClass::kStorage) { // Variable to a storage buffer state.AddAccesss(ident, { diff --git a/src/transform/first_index_offset.cc b/src/transform/first_index_offset.cc index 20e869db62..1ef89e18ec 100644 --- a/src/transform/first_index_offset.cc +++ b/src/transform/first_index_offset.cc @@ -148,19 +148,22 @@ Output FirstIndexOffset::Run(const Program* in, const DataMap& data) { // Fix up all references to the builtins with the offsets ctx.ReplaceAll([=, &ctx](ast::Expression* expr) -> ast::Expression* { - auto* sem = ctx.src->Sem().Get(expr); - if (auto* user = sem->As()) { - auto it = builtin_vars.find(user->Variable()); - if (it != builtin_vars.end()) { - return ctx.dst->Add(ctx.CloneWithoutTransform(expr), - ctx.dst->MemberAccessor(buffer_name, it->second)); + if (auto* sem = ctx.src->Sem().Get(expr)) { + if (auto* user = sem->As()) { + auto it = builtin_vars.find(user->Variable()); + if (it != builtin_vars.end()) { + return ctx.dst->Add( + ctx.CloneWithoutTransform(expr), + ctx.dst->MemberAccessor(buffer_name, it->second)); + } } - } - if (auto* access = sem->As()) { - auto it = builtin_members.find(access->Member()); - if (it != builtin_members.end()) { - return ctx.dst->Add(ctx.CloneWithoutTransform(expr), - ctx.dst->MemberAccessor(buffer_name, it->second)); + if (auto* access = sem->As()) { + auto it = builtin_members.find(access->Member()); + if (it != builtin_members.end()) { + return ctx.dst->Add( + ctx.CloneWithoutTransform(expr), + ctx.dst->MemberAccessor(buffer_name, it->second)); + } } } // Not interested in this experssion. Just clone.