Move TypeManager from tint::Context to ast::Module

Bug: tint:307
Bug: tint:337
Change-Id: I726cdf89182813ba6f468f8ac35e5d44b22e1e1f
Reviewed-on: https://dawn-review.googlesource.com/c/tint/+/33666
Commit-Queue: Ben Clayton <bclayton@google.com>
Reviewed-by: David Neto <dneto@google.com>
This commit is contained in:
Ben Clayton 2020-11-23 19:50:55 +00:00 committed by Commit Bot service account
parent 3e67c5dba6
commit 0fb5168fc7
28 changed files with 211 additions and 176 deletions

View File

@ -26,7 +26,7 @@ TypesBuilder::TypesBuilder(TypeManager* tm)
tm_(tm) {} tm_(tm) {}
Builder::Builder(tint::Context* c, tint::ast::Module* m) Builder::Builder(tint::Context* c, tint::ast::Module* m)
: ctx(c), mod(m), ty(&c->type_mgr()) {} : ctx(c), mod(m), ty(&m->type_mgr()) {}
Builder::~Builder() = default; Builder::~Builder() = default;
ast::Variable* Builder::Var(const std::string& name, ast::Variable* Builder::Var(const std::string& name,

View File

@ -22,6 +22,7 @@
#include "src/ast/function.h" #include "src/ast/function.h"
#include "src/ast/type/alias_type.h" #include "src/ast/type/alias_type.h"
#include "src/ast/type_manager.h"
#include "src/ast/variable.h" #include "src/ast/variable.h"
namespace tint { namespace tint {
@ -77,6 +78,9 @@ class Module {
/// @returns a string representation of the module /// @returns a string representation of the module
std::string to_str() const; std::string to_str() const;
/// @returns the Type Manager
ast::TypeManager& type_mgr() { return type_mgr_; }
/// Creates a new `ast::Node` owned by the Module. When the Module is /// Creates a new `ast::Node` owned by the Module. When the Module is
/// destructed, the `ast::Node` will also be destructed. /// destructed, the `ast::Node` will also be destructed.
/// @param args the arguments to pass to the type constructor /// @param args the arguments to pass to the type constructor
@ -99,6 +103,7 @@ class Module {
std::vector<type::Type*> constructed_types_; std::vector<type::Type*> constructed_types_;
FunctionList functions_; FunctionList functions_;
std::vector<std::unique_ptr<ast::Node>> ast_nodes_; std::vector<std::unique_ptr<ast::Node>> ast_nodes_;
ast::TypeManager type_mgr_;
}; };
} // namespace ast } // namespace ast

View File

@ -79,10 +79,10 @@ TEST_F(StorageTextureTypeTest, TypeName) {
TEST_F(StorageTextureTypeTest, F32Type) { TEST_F(StorageTextureTypeTest, F32Type) {
Context ctx; Context ctx;
ast::type::Type* s = ctx.type_mgr().Get(std::make_unique<StorageTextureType>( ast::Module mod;
ast::type::Type* s = mod.type_mgr().Get(std::make_unique<StorageTextureType>(
TextureDimension::k2dArray, AccessControl::kReadOnly, TextureDimension::k2dArray, AccessControl::kReadOnly,
ImageFormat::kRgba32Float)); ImageFormat::kRgba32Float));
ast::Module mod;
TypeDeterminer td(&ctx, &mod); TypeDeterminer td(&ctx, &mod);
ASSERT_TRUE(td.Determine()) << td.error(); ASSERT_TRUE(td.Determine()) << td.error();
@ -93,10 +93,10 @@ TEST_F(StorageTextureTypeTest, F32Type) {
TEST_F(StorageTextureTypeTest, U32Type) { TEST_F(StorageTextureTypeTest, U32Type) {
Context ctx; Context ctx;
ast::type::Type* s = ctx.type_mgr().Get(std::make_unique<StorageTextureType>( ast::Module mod;
ast::type::Type* s = mod.type_mgr().Get(std::make_unique<StorageTextureType>(
TextureDimension::k2dArray, AccessControl::kReadOnly, TextureDimension::k2dArray, AccessControl::kReadOnly,
ImageFormat::kRgba8Unorm)); ImageFormat::kRgba8Unorm));
ast::Module mod;
TypeDeterminer td(&ctx, &mod); TypeDeterminer td(&ctx, &mod);
ASSERT_TRUE(td.Determine()) << td.error(); ASSERT_TRUE(td.Determine()) << td.error();
@ -107,10 +107,10 @@ TEST_F(StorageTextureTypeTest, U32Type) {
TEST_F(StorageTextureTypeTest, I32Type) { TEST_F(StorageTextureTypeTest, I32Type) {
Context ctx; Context ctx;
ast::type::Type* s = ctx.type_mgr().Get(std::make_unique<StorageTextureType>( ast::Module mod;
ast::type::Type* s = mod.type_mgr().Get(std::make_unique<StorageTextureType>(
TextureDimension::k2dArray, AccessControl::kReadOnly, TextureDimension::k2dArray, AccessControl::kReadOnly,
ImageFormat::kRgba32Sint)); ImageFormat::kRgba32Sint));
ast::Module mod;
TypeDeterminer td(&ctx, &mod); TypeDeterminer td(&ctx, &mod);
ASSERT_TRUE(td.Determine()) << td.error(); ASSERT_TRUE(td.Determine()) << td.error();

View File

@ -20,7 +20,7 @@ namespace tint {
namespace ast { namespace ast {
TypeManager::TypeManager() = default; TypeManager::TypeManager() = default;
TypeManager::TypeManager(TypeManager&&) = default;
TypeManager::~TypeManager() = default; TypeManager::~TypeManager() = default;
void TypeManager::Reset() { void TypeManager::Reset() {

View File

@ -29,6 +29,8 @@ namespace ast {
class TypeManager { class TypeManager {
public: public:
TypeManager(); TypeManager();
/// Move constructor
TypeManager(TypeManager&&);
~TypeManager(); ~TypeManager();
/// Clears all registered types. /// Clears all registered types.

View File

@ -27,8 +27,4 @@ Context::Context(std::unique_ptr<Namer> namer) : namer_(std::move(namer)) {}
Context::~Context() = default; Context::~Context() = default;
void Context::Reset() {
type_mgr_.Reset();
}
} // namespace tint } // namespace tint

View File

@ -22,7 +22,6 @@
#include <utility> #include <utility>
#include <vector> #include <vector>
#include "src/ast/type_manager.h"
#include "src/namer.h" #include "src/namer.h"
namespace tint { namespace tint {
@ -42,17 +41,11 @@ class Context {
explicit Context(std::unique_ptr<Namer> namer); explicit Context(std::unique_ptr<Namer> namer);
/// Destructor /// Destructor
~Context(); ~Context();
/// Resets the state of this context.
void Reset();
/// @returns the Type Manager
ast::TypeManager& type_mgr() { return type_mgr_; }
/// @returns the namer object /// @returns the namer object
Namer* namer() const { return namer_.get(); } Namer* namer() const { return namer_.get(); }
private: private:
ast::TypeManager type_mgr_;
std::unique_ptr<Namer> namer_; std::unique_ptr<Namer> namer_;
}; };

View File

@ -3253,7 +3253,7 @@ ast::type::Type* FunctionEmitter::RemapStorageClass(ast::type::Type* type,
const auto* ast_ptr_type = type->AsPointer(); const auto* ast_ptr_type = type->AsPointer();
const auto sc = GetStorageClassForPointerValue(result_id); const auto sc = GetStorageClassForPointerValue(result_id);
if (ast_ptr_type->storage_class() != sc) { if (ast_ptr_type->storage_class() != sc) {
return parser_impl_.context().type_mgr().Get( return parser_impl_.get_module().type_mgr().Get(
std::make_unique<ast::type::PointerType>(ast_ptr_type->type(), sc)); std::make_unique<ast::type::PointerType>(ast_ptr_type->type(), sc));
} }
} }

View File

@ -196,7 +196,8 @@ ParserImpl::ParserImpl(Context* ctx, const std::vector<uint32_t>& spv_binary)
: Reader(ctx), : Reader(ctx),
spv_binary_(spv_binary), spv_binary_(spv_binary),
fail_stream_(&success_, &errors_), fail_stream_(&success_, &errors_),
bool_type_(ctx->type_mgr().Get(std::make_unique<ast::type::BoolType>())), bool_type_(
ast_module_.type_mgr().Get(std::make_unique<ast::type::BoolType>())),
namer_(fail_stream_), namer_(fail_stream_),
enum_converter_(fail_stream_), enum_converter_(fail_stream_),
tools_context_(kInputEnv) { tools_context_(kInputEnv) {
@ -285,7 +286,8 @@ ast::type::Type* ParserImpl::ConvertType(uint32_t type_id) {
switch (spirv_type->kind()) { switch (spirv_type->kind()) {
case spvtools::opt::analysis::Type::kVoid: case spvtools::opt::analysis::Type::kVoid:
return save(ctx_.type_mgr().Get(std::make_unique<ast::type::VoidType>())); return save(
ast_module_.type_mgr().Get(std::make_unique<ast::type::VoidType>()));
case spvtools::opt::analysis::Type::kBool: case spvtools::opt::analysis::Type::kBool:
return save(bool_type_); return save(bool_type_);
case spvtools::opt::analysis::Type::kInteger: case spvtools::opt::analysis::Type::kInteger:
@ -315,7 +317,8 @@ ast::type::Type* ParserImpl::ConvertType(uint32_t type_id) {
case spvtools::opt::analysis::Type::kImage: case spvtools::opt::analysis::Type::kImage:
// Fake it for sampler and texture types. These are handled in an // Fake it for sampler and texture types. These are handled in an
// entirely different way. // entirely different way.
return save(ctx_.type_mgr().Get(std::make_unique<ast::type::VoidType>())); return save(
ast_module_.type_mgr().Get(std::make_unique<ast::type::VoidType>()));
default: default:
break; break;
} }
@ -649,9 +652,9 @@ ast::type::Type* ParserImpl::ConvertType(
const spvtools::opt::analysis::Integer* int_ty) { const spvtools::opt::analysis::Integer* int_ty) {
if (int_ty->width() == 32) { if (int_ty->width() == 32) {
auto* signed_ty = auto* signed_ty =
ctx_.type_mgr().Get(std::make_unique<ast::type::I32Type>()); ast_module_.type_mgr().Get(std::make_unique<ast::type::I32Type>());
auto* unsigned_ty = auto* unsigned_ty =
ctx_.type_mgr().Get(std::make_unique<ast::type::U32Type>()); ast_module_.type_mgr().Get(std::make_unique<ast::type::U32Type>());
signed_type_for_[unsigned_ty] = signed_ty; signed_type_for_[unsigned_ty] = signed_ty;
unsigned_type_for_[signed_ty] = unsigned_ty; unsigned_type_for_[signed_ty] = unsigned_ty;
return int_ty->IsSigned() ? signed_ty : unsigned_ty; return int_ty->IsSigned() ? signed_ty : unsigned_ty;
@ -663,7 +666,7 @@ ast::type::Type* ParserImpl::ConvertType(
ast::type::Type* ParserImpl::ConvertType( ast::type::Type* ParserImpl::ConvertType(
const spvtools::opt::analysis::Float* float_ty) { const spvtools::opt::analysis::Float* float_ty) {
if (float_ty->width() == 32) { if (float_ty->width() == 32) {
return ctx_.type_mgr().Get(std::make_unique<ast::type::F32Type>()); return ast_module_.type_mgr().Get(std::make_unique<ast::type::F32Type>());
} }
Fail() << "unhandled float width: " << float_ty->width(); Fail() << "unhandled float width: " << float_ty->width();
return nullptr; return nullptr;
@ -676,18 +679,18 @@ ast::type::Type* ParserImpl::ConvertType(
if (ast_elem_ty == nullptr) { if (ast_elem_ty == nullptr) {
return nullptr; return nullptr;
} }
auto* this_ty = ctx_.type_mgr().Get( auto* this_ty = ast_module_.type_mgr().Get(
std::make_unique<ast::type::VectorType>(ast_elem_ty, num_elem)); std::make_unique<ast::type::VectorType>(ast_elem_ty, num_elem));
// Generate the opposite-signedness vector type, if this type is integral. // Generate the opposite-signedness vector type, if this type is integral.
if (unsigned_type_for_.count(ast_elem_ty)) { if (unsigned_type_for_.count(ast_elem_ty)) {
auto* other_ty = auto* other_ty =
ctx_.type_mgr().Get(std::make_unique<ast::type::VectorType>( ast_module_.type_mgr().Get(std::make_unique<ast::type::VectorType>(
unsigned_type_for_[ast_elem_ty], num_elem)); unsigned_type_for_[ast_elem_ty], num_elem));
signed_type_for_[other_ty] = this_ty; signed_type_for_[other_ty] = this_ty;
unsigned_type_for_[this_ty] = other_ty; unsigned_type_for_[this_ty] = other_ty;
} else if (signed_type_for_.count(ast_elem_ty)) { } else if (signed_type_for_.count(ast_elem_ty)) {
auto* other_ty = auto* other_ty =
ctx_.type_mgr().Get(std::make_unique<ast::type::VectorType>( ast_module_.type_mgr().Get(std::make_unique<ast::type::VectorType>(
signed_type_for_[ast_elem_ty], num_elem)); signed_type_for_[ast_elem_ty], num_elem));
unsigned_type_for_[other_ty] = this_ty; unsigned_type_for_[other_ty] = this_ty;
signed_type_for_[this_ty] = other_ty; signed_type_for_[this_ty] = other_ty;
@ -705,7 +708,7 @@ ast::type::Type* ParserImpl::ConvertType(
if (ast_scalar_ty == nullptr) { if (ast_scalar_ty == nullptr) {
return nullptr; return nullptr;
} }
return ctx_.type_mgr().Get(std::make_unique<ast::type::MatrixType>( return ast_module_.type_mgr().Get(std::make_unique<ast::type::MatrixType>(
ast_scalar_ty, num_rows, num_columns)); ast_scalar_ty, num_rows, num_columns));
} }
@ -719,7 +722,7 @@ ast::type::Type* ParserImpl::ConvertType(
if (!ApplyArrayDecorations(rtarr_ty, ast_type.get())) { if (!ApplyArrayDecorations(rtarr_ty, ast_type.get())) {
return nullptr; return nullptr;
} }
return ctx_.type_mgr().Get(std::move(ast_type)); return ast_module_.type_mgr().Get(std::move(ast_type));
} }
ast::type::Type* ParserImpl::ConvertType( ast::type::Type* ParserImpl::ConvertType(
@ -764,7 +767,7 @@ ast::type::Type* ParserImpl::ConvertType(
if (remap_buffer_block_type_.count(elem_type_id)) { if (remap_buffer_block_type_.count(elem_type_id)) {
remap_buffer_block_type_.insert(type_mgr_->GetId(arr_ty)); remap_buffer_block_type_.insert(type_mgr_->GetId(arr_ty));
} }
return ctx_.type_mgr().Get(std::move(ast_type)); return ast_module_.type_mgr().Get(std::move(ast_type));
} }
bool ParserImpl::ApplyArrayDecorations( bool ParserImpl::ApplyArrayDecorations(
@ -892,7 +895,7 @@ ast::type::Type* ParserImpl::ConvertType(
auto ast_struct_type = std::make_unique<ast::type::StructType>( auto ast_struct_type = std::make_unique<ast::type::StructType>(
namer_.GetName(type_id), ast_struct); namer_.GetName(type_id), ast_struct);
auto* result = ctx_.type_mgr().Get(std::move(ast_struct_type)); auto* result = ast_module_.type_mgr().Get(std::move(ast_struct_type));
id_to_type_[type_id] = result; id_to_type_[type_id] = result;
if (num_non_writable_members == members.size()) { if (num_non_writable_members == members.size()) {
read_only_struct_types_.insert(result); read_only_struct_types_.insert(result);
@ -932,7 +935,7 @@ ast::type::Type* ParserImpl::ConvertType(
ast_storage_class = ast::StorageClass::kStorageBuffer; ast_storage_class = ast::StorageClass::kStorageBuffer;
remap_buffer_block_type_.insert(type_id); remap_buffer_block_type_.insert(type_id);
} }
return ctx_.type_mgr().Get( return ast_module_.type_mgr().Get(
std::make_unique<ast::type::PointerType>(ast_elem_ty, ast_storage_class)); std::make_unique<ast::type::PointerType>(ast_elem_ty, ast_storage_class));
} }
@ -1062,7 +1065,7 @@ void ParserImpl::MaybeGenerateAlias(uint32_t type_id,
return; return;
} }
const auto name = namer_.GetName(type_id); const auto name = namer_.GetName(type_id);
auto* ast_alias_type = ctx_.type_mgr() auto* ast_alias_type = ast_module_.type_mgr()
.Get(std::make_unique<ast::type::AliasType>( .Get(std::make_unique<ast::type::AliasType>(
name, ast_underlying_type)) name, ast_underlying_type))
->AsAlias(); ->AsAlias();
@ -1166,7 +1169,7 @@ ast::Variable* ParserImpl::MakeVariable(uint32_t id,
auto access = read_only_struct_types_.count(type) auto access = read_only_struct_types_.count(type)
? ast::AccessControl::kReadOnly ? ast::AccessControl::kReadOnly
: ast::AccessControl::kReadWrite; : ast::AccessControl::kReadWrite;
type = ctx_.type_mgr().Get( type = ast_module_.type_mgr().Get(
std::make_unique<ast::type::AccessControlType>(access, type)); std::make_unique<ast::type::AccessControlType>(access, type));
} }
@ -1361,7 +1364,7 @@ ast::Expression* ParserImpl::MakeNullValue(ast::type::Type* type) {
const auto* mat_ty = type->AsMatrix(); const auto* mat_ty = type->AsMatrix();
// Matrix components are columns // Matrix components are columns
auto* column_ty = auto* column_ty =
ctx_.type_mgr().Get(std::make_unique<ast::type::VectorType>( ast_module_.type_mgr().Get(std::make_unique<ast::type::VectorType>(
mat_ty->type(), mat_ty->rows())); mat_ty->type(), mat_ty->rows()));
ast::ExpressionList ast_components; ast::ExpressionList ast_components;
for (size_t i = 0; i < mat_ty->columns(); ++i) { for (size_t i = 0; i < mat_ty->columns(); ++i) {
@ -1443,13 +1446,14 @@ ast::type::Type* ParserImpl::GetSignedIntMatchingShape(ast::type::Type* other) {
if (other == nullptr) { if (other == nullptr) {
Fail() << "no type provided"; Fail() << "no type provided";
} }
auto* i32 = ctx_.type_mgr().Get(std::make_unique<ast::type::I32Type>()); auto* i32 =
ast_module_.type_mgr().Get(std::make_unique<ast::type::I32Type>());
if (other->IsF32() || other->IsU32() || other->IsI32()) { if (other->IsF32() || other->IsU32() || other->IsI32()) {
return i32; return i32;
} }
auto* vec_ty = other->AsVector(); auto* vec_ty = other->AsVector();
if (vec_ty) { if (vec_ty) {
return ctx_.type_mgr().Get( return ast_module_.type_mgr().Get(
std::make_unique<ast::type::VectorType>(i32, vec_ty->size())); std::make_unique<ast::type::VectorType>(i32, vec_ty->size()));
} }
Fail() << "required numeric scalar or vector, but got " << other->type_name(); Fail() << "required numeric scalar or vector, but got " << other->type_name();
@ -1462,13 +1466,14 @@ ast::type::Type* ParserImpl::GetUnsignedIntMatchingShape(
Fail() << "no type provided"; Fail() << "no type provided";
return nullptr; return nullptr;
} }
auto* u32 = ctx_.type_mgr().Get(std::make_unique<ast::type::U32Type>()); auto* u32 =
ast_module_.type_mgr().Get(std::make_unique<ast::type::U32Type>());
if (other->IsF32() || other->IsU32() || other->IsI32()) { if (other->IsF32() || other->IsU32() || other->IsI32()) {
return u32; return u32;
} }
auto* vec_ty = other->AsVector(); auto* vec_ty = other->AsVector();
if (vec_ty) { if (vec_ty) {
return ctx_.type_mgr().Get( return ast_module_.type_mgr().Get(
std::make_unique<ast::type::VectorType>(u32, vec_ty->size())); std::make_unique<ast::type::VectorType>(u32, vec_ty->size()));
} }
Fail() << "required numeric scalar or vector, but got " << other->type_name(); Fail() << "required numeric scalar or vector, but got " << other->type_name();
@ -1628,7 +1633,7 @@ ast::type::Type* ParserImpl::GetTypeForHandleVar(
ast::type::Type* ast_store_type = nullptr; ast::type::Type* ast_store_type = nullptr;
if (usage.IsSampler()) { if (usage.IsSampler()) {
ast_store_type = ast_store_type =
ctx_.type_mgr().Get(std::make_unique<ast::type::SamplerType>( ast_module_.type_mgr().Get(std::make_unique<ast::type::SamplerType>(
usage.IsComparisonSampler() usage.IsComparisonSampler()
? ast::type::SamplerKind::kComparisonSampler ? ast::type::SamplerKind::kComparisonSampler
: ast::type::SamplerKind::kSampler)); : ast::type::SamplerKind::kSampler));
@ -1684,16 +1689,16 @@ ast::type::Type* ParserImpl::GetTypeForHandleVar(
// OpImage variable with an OpImage*Dref* instruction. In WGSL we must // OpImage variable with an OpImage*Dref* instruction. In WGSL we must
// treat that as a depth texture. // treat that as a depth texture.
if (image_type->depth() || usage.IsDepthTexture()) { if (image_type->depth() || usage.IsDepthTexture()) {
ast_store_type = ctx_.type_mgr().Get( ast_store_type = ast_module_.type_mgr().Get(
std::make_unique<ast::type::DepthTextureType>(dim)); std::make_unique<ast::type::DepthTextureType>(dim));
} else if (image_type->is_multisampled()) { } else if (image_type->is_multisampled()) {
// Multisampled textures are never depth textures. // Multisampled textures are never depth textures.
ast_store_type = ctx_.type_mgr().Get( ast_store_type = ast_module_.type_mgr().Get(
std::make_unique<ast::type::MultisampledTextureType>( std::make_unique<ast::type::MultisampledTextureType>(
dim, ast_sampled_component_type)); dim, ast_sampled_component_type));
} else { } else {
ast_store_type = ast_store_type = ast_module_.type_mgr().Get(
ctx_.type_mgr().Get(std::make_unique<ast::type::SampledTextureType>( std::make_unique<ast::type::SampledTextureType>(
dim, ast_sampled_component_type)); dim, ast_sampled_component_type));
} }
} else { } else {
@ -1726,7 +1731,7 @@ ast::type::Type* ParserImpl::GetTypeForHandleVar(
if (format == ast::type::ImageFormat::kNone) { if (format == ast::type::ImageFormat::kNone) {
return nullptr; return nullptr;
} }
ast_store_type = ctx_.type_mgr().Get( ast_store_type = ast_module_.type_mgr().Get(
std::make_unique<ast::type::StorageTextureType>(dim, access, format)); std::make_unique<ast::type::StorageTextureType>(dim, access, format));
} }
} else { } else {
@ -1736,7 +1741,7 @@ ast::type::Type* ParserImpl::GetTypeForHandleVar(
return nullptr; return nullptr;
} }
// Form the pointer type. // Form the pointer type.
return ctx_.type_mgr().Get(std::make_unique<ast::type::PointerType>( return ast_module_.type_mgr().Get(std::make_unique<ast::type::PointerType>(
ast_store_type, ast::StorageClass::kUniformConstant)); ast_store_type, ast::StorageClass::kUniformConstant));
} }

View File

@ -162,8 +162,8 @@ struct BlockCounters {
} // namespace } // namespace
ParserImpl::ParserImpl(Context* ctx, Source::File const* file) ParserImpl::ParserImpl(Context*, Source::File const* file)
: ctx_(*ctx), lexer_(std::make_unique<Lexer>(file)) {} : lexer_(std::make_unique<Lexer>(file)) {}
ParserImpl::~ParserImpl() = default; ParserImpl::~ParserImpl() = default;
@ -308,7 +308,7 @@ Expect<bool> ParserImpl::expect_global_decl() {
if (!expect("struct declaration", Token::Type::kSemicolon)) if (!expect("struct declaration", Token::Type::kSemicolon))
return Failure::kErrored; return Failure::kErrored;
auto* type = ctx_.type_mgr().Get(std::move(str.value)); auto* type = module_.type_mgr().Get(std::move(str.value));
register_constructed(type->AsStruct()->name(), type); register_constructed(type->AsStruct()->name(), type);
module_.AddConstructedType(type); module_.AddConstructedType(type);
return true; return true;
@ -462,8 +462,9 @@ Maybe<ast::type::Type*> ParserImpl::texture_sampler_types() {
if (subtype.errored) if (subtype.errored)
return Failure::kErrored; return Failure::kErrored;
return ctx_.type_mgr().Get(std::make_unique<ast::type::SampledTextureType>( return module_.type_mgr().Get(
dim.value, subtype.value)); std::make_unique<ast::type::SampledTextureType>(dim.value,
subtype.value));
} }
auto ms_dim = multisampled_texture_type(); auto ms_dim = multisampled_texture_type();
@ -474,7 +475,7 @@ Maybe<ast::type::Type*> ParserImpl::texture_sampler_types() {
if (subtype.errored) if (subtype.errored)
return Failure::kErrored; return Failure::kErrored;
return ctx_.type_mgr().Get( return module_.type_mgr().Get(
std::make_unique<ast::type::MultisampledTextureType>(ms_dim.value, std::make_unique<ast::type::MultisampledTextureType>(ms_dim.value,
subtype.value)); subtype.value));
} }
@ -489,8 +490,9 @@ Maybe<ast::type::Type*> ParserImpl::texture_sampler_types() {
if (format.errored) if (format.errored)
return Failure::kErrored; return Failure::kErrored;
return ctx_.type_mgr().Get(std::make_unique<ast::type::StorageTextureType>( return module_.type_mgr().Get(
storage->first, storage->second, format.value)); std::make_unique<ast::type::StorageTextureType>(
storage->first, storage->second, format.value));
} }
return Failure::kNoMatch; return Failure::kNoMatch;
@ -501,11 +503,11 @@ Maybe<ast::type::Type*> ParserImpl::texture_sampler_types() {
// | SAMPLER_COMPARISON // | SAMPLER_COMPARISON
Maybe<ast::type::Type*> ParserImpl::sampler_type() { Maybe<ast::type::Type*> ParserImpl::sampler_type() {
if (match(Token::Type::kSampler)) if (match(Token::Type::kSampler))
return ctx_.type_mgr().Get(std::make_unique<ast::type::SamplerType>( return module_.type_mgr().Get(std::make_unique<ast::type::SamplerType>(
ast::type::SamplerKind::kSampler)); ast::type::SamplerKind::kSampler));
if (match(Token::Type::kComparisonSampler)) if (match(Token::Type::kComparisonSampler))
return ctx_.type_mgr().Get(std::make_unique<ast::type::SamplerType>( return module_.type_mgr().Get(std::make_unique<ast::type::SamplerType>(
ast::type::SamplerKind::kComparisonSampler)); ast::type::SamplerKind::kComparisonSampler));
return Failure::kNoMatch; return Failure::kNoMatch;
@ -634,19 +636,19 @@ ParserImpl::storage_texture_type() {
// | TEXTURE_DEPTH_CUBE_ARRAY // | TEXTURE_DEPTH_CUBE_ARRAY
Maybe<ast::type::Type*> ParserImpl::depth_texture_type() { Maybe<ast::type::Type*> ParserImpl::depth_texture_type() {
if (match(Token::Type::kTextureDepth2d)) if (match(Token::Type::kTextureDepth2d))
return ctx_.type_mgr().Get(std::make_unique<ast::type::DepthTextureType>( return module_.type_mgr().Get(std::make_unique<ast::type::DepthTextureType>(
ast::type::TextureDimension::k2d)); ast::type::TextureDimension::k2d));
if (match(Token::Type::kTextureDepth2dArray)) if (match(Token::Type::kTextureDepth2dArray))
return ctx_.type_mgr().Get(std::make_unique<ast::type::DepthTextureType>( return module_.type_mgr().Get(std::make_unique<ast::type::DepthTextureType>(
ast::type::TextureDimension::k2dArray)); ast::type::TextureDimension::k2dArray));
if (match(Token::Type::kTextureDepthCube)) if (match(Token::Type::kTextureDepthCube))
return ctx_.type_mgr().Get(std::make_unique<ast::type::DepthTextureType>( return module_.type_mgr().Get(std::make_unique<ast::type::DepthTextureType>(
ast::type::TextureDimension::kCube)); ast::type::TextureDimension::kCube));
if (match(Token::Type::kTextureDepthCubeArray)) if (match(Token::Type::kTextureDepthCubeArray))
return ctx_.type_mgr().Get(std::make_unique<ast::type::DepthTextureType>( return module_.type_mgr().Get(std::make_unique<ast::type::DepthTextureType>(
ast::type::TextureDimension::kCubeArray)); ast::type::TextureDimension::kCubeArray));
return Failure::kNoMatch; return Failure::kNoMatch;
@ -832,7 +834,7 @@ Expect<ParserImpl::TypedIdentifier> ParserImpl::expect_variable_ident_decl(
for (auto* deco : access_decos) { for (auto* deco : access_decos) {
// If we have an access control decoration then we take it and wrap our // If we have an access control decoration then we take it and wrap our
// type up with that decoration // type up with that decoration
ty = ctx_.type_mgr().Get(std::make_unique<ast::type::AccessControlType>( ty = module_.type_mgr().Get(std::make_unique<ast::type::AccessControlType>(
deco->AsAccess()->value(), ty)); deco->AsAccess()->value(), ty));
} }
@ -892,7 +894,7 @@ Maybe<ast::type::Type*> ParserImpl::type_alias() {
if (!type.matched) if (!type.matched)
return add_error(peek(), "invalid type alias"); return add_error(peek(), "invalid type alias");
auto* alias = ctx_.type_mgr().Get( auto* alias = module_.type_mgr().Get(
std::make_unique<ast::type::AliasType>(name.value, type.value)); std::make_unique<ast::type::AliasType>(name.value, type.value));
register_constructed(name.value, alias); register_constructed(name.value, alias);
@ -951,16 +953,16 @@ Maybe<ast::type::Type*> ParserImpl::type_decl(ast::DecorationList& decos) {
} }
if (match(Token::Type::kBool)) if (match(Token::Type::kBool))
return ctx_.type_mgr().Get(std::make_unique<ast::type::BoolType>()); return module_.type_mgr().Get(std::make_unique<ast::type::BoolType>());
if (match(Token::Type::kF32)) if (match(Token::Type::kF32))
return ctx_.type_mgr().Get(std::make_unique<ast::type::F32Type>()); return module_.type_mgr().Get(std::make_unique<ast::type::F32Type>());
if (match(Token::Type::kI32)) if (match(Token::Type::kI32))
return ctx_.type_mgr().Get(std::make_unique<ast::type::I32Type>()); return module_.type_mgr().Get(std::make_unique<ast::type::I32Type>());
if (match(Token::Type::kU32)) if (match(Token::Type::kU32))
return ctx_.type_mgr().Get(std::make_unique<ast::type::U32Type>()); return module_.type_mgr().Get(std::make_unique<ast::type::U32Type>());
if (t.IsVec2() || t.IsVec3() || t.IsVec4()) { if (t.IsVec2() || t.IsVec3() || t.IsVec4()) {
next(); // Consume the peek next(); // Consume the peek
@ -1018,7 +1020,7 @@ Expect<ast::type::Type*> ParserImpl::expect_type_decl_pointer() {
if (subtype.errored) if (subtype.errored)
return Failure::kErrored; return Failure::kErrored;
return ctx_.type_mgr().Get( return module_.type_mgr().Get(
std::make_unique<ast::type::PointerType>(subtype.value, sc.value)); std::make_unique<ast::type::PointerType>(subtype.value, sc.value));
}); });
} }
@ -1036,7 +1038,7 @@ Expect<ast::type::Type*> ParserImpl::expect_type_decl_vector(Token t) {
if (subtype.errored) if (subtype.errored)
return Failure::kErrored; return Failure::kErrored;
return ctx_.type_mgr().Get( return module_.type_mgr().Get(
std::make_unique<ast::type::VectorType>(subtype.value, count)); std::make_unique<ast::type::VectorType>(subtype.value, count));
} }
@ -1059,7 +1061,7 @@ Expect<ast::type::Type*> ParserImpl::expect_type_decl_array(
auto ty = std::make_unique<ast::type::ArrayType>(subtype.value, size); auto ty = std::make_unique<ast::type::ArrayType>(subtype.value, size);
ty->set_decorations(std::move(decos)); ty->set_decorations(std::move(decos));
return ctx_.type_mgr().Get(std::move(ty)); return module_.type_mgr().Get(std::move(ty));
}); });
} }
@ -1083,7 +1085,7 @@ Expect<ast::type::Type*> ParserImpl::expect_type_decl_matrix(Token t) {
if (subtype.errored) if (subtype.errored)
return Failure::kErrored; return Failure::kErrored;
return ctx_.type_mgr().Get( return module_.type_mgr().Get(
std::make_unique<ast::type::MatrixType>(subtype.value, rows, columns)); std::make_unique<ast::type::MatrixType>(subtype.value, rows, columns));
} }
@ -1252,7 +1254,7 @@ Maybe<ast::Function*> ParserImpl::function_decl(ast::DecorationList& decos) {
// | VOID // | VOID
Maybe<ast::type::Type*> ParserImpl::function_type_decl() { Maybe<ast::type::Type*> ParserImpl::function_type_decl() {
if (match(Token::Type::kVoid)) if (match(Token::Type::kVoid))
return ctx_.type_mgr().Get(std::make_unique<ast::type::VoidType>()); return module_.type_mgr().Get(std::make_unique<ast::type::VoidType>());
return type_decl(); return type_decl();
} }
@ -2611,23 +2613,25 @@ Maybe<ast::AssignmentStatement*> ParserImpl::assignment_stmt() {
Maybe<ast::Literal*> ParserImpl::const_literal() { Maybe<ast::Literal*> ParserImpl::const_literal() {
auto t = peek(); auto t = peek();
if (match(Token::Type::kTrue)) { if (match(Token::Type::kTrue)) {
auto* type = ctx_.type_mgr().Get(std::make_unique<ast::type::BoolType>()); auto* type =
module_.type_mgr().Get(std::make_unique<ast::type::BoolType>());
return create<ast::BoolLiteral>(type, true); return create<ast::BoolLiteral>(type, true);
} }
if (match(Token::Type::kFalse)) { if (match(Token::Type::kFalse)) {
auto* type = ctx_.type_mgr().Get(std::make_unique<ast::type::BoolType>()); auto* type =
module_.type_mgr().Get(std::make_unique<ast::type::BoolType>());
return create<ast::BoolLiteral>(type, false); return create<ast::BoolLiteral>(type, false);
} }
if (match(Token::Type::kSintLiteral)) { if (match(Token::Type::kSintLiteral)) {
auto* type = ctx_.type_mgr().Get(std::make_unique<ast::type::I32Type>()); auto* type = module_.type_mgr().Get(std::make_unique<ast::type::I32Type>());
return create<ast::SintLiteral>(type, t.to_i32()); return create<ast::SintLiteral>(type, t.to_i32());
} }
if (match(Token::Type::kUintLiteral)) { if (match(Token::Type::kUintLiteral)) {
auto* type = ctx_.type_mgr().Get(std::make_unique<ast::type::U32Type>()); auto* type = module_.type_mgr().Get(std::make_unique<ast::type::U32Type>());
return create<ast::UintLiteral>(type, t.to_u32()); return create<ast::UintLiteral>(type, t.to_u32());
} }
if (match(Token::Type::kFloatLiteral)) { if (match(Token::Type::kFloatLiteral)) {
auto* type = ctx_.type_mgr().Get(std::make_unique<ast::type::F32Type>()); auto* type = module_.type_mgr().Get(std::make_unique<ast::type::F32Type>());
return create<ast::FloatLiteral>(type, t.to_f32()); return create<ast::FloatLiteral>(type, t.to_f32());
} }
return Failure::kNoMatch; return Failure::kNoMatch;

View File

@ -251,6 +251,9 @@ class ParserImpl {
/// @returns the module. The module in the parser will be reset after this. /// @returns the module. The module in the parser will be reset after this.
ast::Module module() { return std::move(module_); } ast::Module module() { return std::move(module_); }
/// @returns a pointer to the module, without resetting it.
ast::Module& get_module() { return module_; }
/// @returns the next token /// @returns the next token
Token next(); Token next();
/// @returns the next token without advancing /// @returns the next token without advancing
@ -768,7 +771,6 @@ class ParserImpl {
return module_.create<T>(std::forward<ARGS>(args)...); return module_.create<T>(std::forward<ARGS>(args)...);
} }
Context& ctx_;
diag::List diags_; diag::List diags_;
std::unique_ptr<Lexer> lexer_; std::unique_ptr<Lexer> lexer_;
std::deque<Token> token_queue_; std::deque<Token> token_queue_;

View File

@ -27,9 +27,11 @@ namespace wgsl {
namespace { namespace {
TEST_F(ParserImplTest, FunctionTypeDecl_Void) { TEST_F(ParserImplTest, FunctionTypeDecl_Void) {
auto* v = tm()->Get(std::make_unique<ast::type::VoidType>());
auto p = parser("void"); auto p = parser("void");
auto& mod = p->get_module();
auto* v = mod.type_mgr().Get(std::make_unique<ast::type::VoidType>());
auto e = p->function_type_decl(); auto e = p->function_type_decl();
EXPECT_TRUE(e.matched); EXPECT_TRUE(e.matched);
EXPECT_FALSE(e.errored); EXPECT_FALSE(e.errored);
@ -38,10 +40,13 @@ TEST_F(ParserImplTest, FunctionTypeDecl_Void) {
} }
TEST_F(ParserImplTest, FunctionTypeDecl_Type) { TEST_F(ParserImplTest, FunctionTypeDecl_Type) {
auto* f32 = tm()->Get(std::make_unique<ast::type::F32Type>());
auto* vec2 = tm()->Get(std::make_unique<ast::type::VectorType>(f32, 2));
auto p = parser("vec2<f32>"); auto p = parser("vec2<f32>");
auto& mod = p->get_module();
auto* f32 = mod.type_mgr().Get(std::make_unique<ast::type::F32Type>());
auto* vec2 =
mod.type_mgr().Get(std::make_unique<ast::type::VectorType>(f32, 2));
auto e = p->function_type_decl(); auto e = p->function_type_decl();
EXPECT_TRUE(e.matched); EXPECT_TRUE(e.matched);
EXPECT_FALSE(e.errored); EXPECT_FALSE(e.errored);

View File

@ -34,7 +34,7 @@ TEST_F(ParserImplTest, GlobalDecl_GlobalVariable) {
p->expect_global_decl(); p->expect_global_decl();
ASSERT_FALSE(p->has_error()) << p->error(); ASSERT_FALSE(p->has_error()) << p->error();
auto m = p->module(); auto& m = p->get_module();
ASSERT_EQ(m.global_variables().size(), 1u); ASSERT_EQ(m.global_variables().size(), 1u);
auto* v = m.global_variables()[0]; auto* v = m.global_variables()[0];
@ -60,7 +60,7 @@ TEST_F(ParserImplTest, GlobalDecl_GlobalConstant) {
p->expect_global_decl(); p->expect_global_decl();
ASSERT_FALSE(p->has_error()) << p->error(); ASSERT_FALSE(p->has_error()) << p->error();
auto m = p->module(); auto& m = p->get_module();
ASSERT_EQ(m.global_variables().size(), 1u); ASSERT_EQ(m.global_variables().size(), 1u);
auto* v = m.global_variables()[0]; auto* v = m.global_variables()[0];
@ -86,7 +86,7 @@ TEST_F(ParserImplTest, GlobalDecl_TypeAlias) {
p->expect_global_decl(); p->expect_global_decl();
ASSERT_FALSE(p->has_error()) << p->error(); ASSERT_FALSE(p->has_error()) << p->error();
auto m = p->module(); auto& m = p->get_module();
ASSERT_EQ(m.constructed_types().size(), 1u); ASSERT_EQ(m.constructed_types().size(), 1u);
ASSERT_TRUE(m.constructed_types()[0]->IsAlias()); ASSERT_TRUE(m.constructed_types()[0]->IsAlias());
EXPECT_EQ(m.constructed_types()[0]->AsAlias()->name(), "A"); EXPECT_EQ(m.constructed_types()[0]->AsAlias()->name(), "A");
@ -101,7 +101,7 @@ type B = A;)");
p->expect_global_decl(); p->expect_global_decl();
ASSERT_FALSE(p->has_error()) << p->error(); ASSERT_FALSE(p->has_error()) << p->error();
auto m = p->module(); auto& m = p->get_module();
ASSERT_EQ(m.constructed_types().size(), 2u); ASSERT_EQ(m.constructed_types().size(), 2u);
ASSERT_TRUE(m.constructed_types()[0]->IsStruct()); ASSERT_TRUE(m.constructed_types()[0]->IsStruct());
auto* str = m.constructed_types()[0]->AsStruct(); auto* str = m.constructed_types()[0]->AsStruct();
@ -132,7 +132,7 @@ TEST_F(ParserImplTest, GlobalDecl_Function) {
p->expect_global_decl(); p->expect_global_decl();
ASSERT_FALSE(p->has_error()) << p->error(); ASSERT_FALSE(p->has_error()) << p->error();
auto m = p->module(); auto& m = p->get_module();
ASSERT_EQ(m.functions().size(), 1u); ASSERT_EQ(m.functions().size(), 1u);
EXPECT_EQ(m.functions()[0]->name(), "main"); EXPECT_EQ(m.functions()[0]->name(), "main");
} }
@ -142,7 +142,7 @@ TEST_F(ParserImplTest, GlobalDecl_Function_WithDecoration) {
p->expect_global_decl(); p->expect_global_decl();
ASSERT_FALSE(p->has_error()) << p->error(); ASSERT_FALSE(p->has_error()) << p->error();
auto m = p->module(); auto& m = p->get_module();
ASSERT_EQ(m.functions().size(), 1u); ASSERT_EQ(m.functions().size(), 1u);
EXPECT_EQ(m.functions()[0]->name(), "main"); EXPECT_EQ(m.functions()[0]->name(), "main");
} }
@ -159,7 +159,7 @@ TEST_F(ParserImplTest, GlobalDecl_ParsesStruct) {
p->expect_global_decl(); p->expect_global_decl();
ASSERT_FALSE(p->has_error()) << p->error(); ASSERT_FALSE(p->has_error()) << p->error();
auto m = p->module(); auto& m = p->get_module();
ASSERT_EQ(m.constructed_types().size(), 1u); ASSERT_EQ(m.constructed_types().size(), 1u);
auto* t = m.constructed_types()[0]; auto* t = m.constructed_types()[0];
@ -174,10 +174,11 @@ TEST_F(ParserImplTest, GlobalDecl_ParsesStruct) {
TEST_F(ParserImplTest, GlobalDecl_Struct_WithStride) { TEST_F(ParserImplTest, GlobalDecl_Struct_WithStride) {
auto p = auto p =
parser("struct A { [[offset(0)]] data: [[stride(4)]] array<f32>; };"); parser("struct A { [[offset(0)]] data: [[stride(4)]] array<f32>; };");
p->expect_global_decl(); p->expect_global_decl();
ASSERT_FALSE(p->has_error()) << p->error(); ASSERT_FALSE(p->has_error()) << p->error();
auto m = p->module(); auto& m = p->get_module();
ASSERT_EQ(m.constructed_types().size(), 1u); ASSERT_EQ(m.constructed_types().size(), 1u);
auto* t = m.constructed_types()[0]; auto* t = m.constructed_types()[0];
@ -201,7 +202,7 @@ TEST_F(ParserImplTest, GlobalDecl_Struct_WithDecoration) {
p->expect_global_decl(); p->expect_global_decl();
ASSERT_FALSE(p->has_error()) << p->error(); ASSERT_FALSE(p->has_error()) << p->error();
auto m = p->module(); auto& m = p->get_module();
ASSERT_EQ(m.constructed_types().size(), 1u); ASSERT_EQ(m.constructed_types().size(), 1u);
auto* t = m.constructed_types()[0]; auto* t = m.constructed_types()[0];

View File

@ -28,9 +28,11 @@ namespace wgsl {
namespace { namespace {
TEST_F(ParserImplTest, ParamList_Single) { TEST_F(ParserImplTest, ParamList_Single) {
auto* i32 = tm()->Get(std::make_unique<ast::type::I32Type>());
auto p = parser("a : i32"); auto p = parser("a : i32");
auto& mod = p->get_module();
auto* i32 = mod.type_mgr().Get(std::make_unique<ast::type::I32Type>());
auto e = p->expect_param_list(); auto e = p->expect_param_list();
ASSERT_FALSE(p->has_error()) << p->error(); ASSERT_FALSE(p->has_error()) << p->error();
ASSERT_FALSE(e.errored); ASSERT_FALSE(e.errored);
@ -47,11 +49,14 @@ TEST_F(ParserImplTest, ParamList_Single) {
} }
TEST_F(ParserImplTest, ParamList_Multiple) { TEST_F(ParserImplTest, ParamList_Multiple) {
auto* i32 = tm()->Get(std::make_unique<ast::type::I32Type>());
auto* f32 = tm()->Get(std::make_unique<ast::type::F32Type>());
auto* vec2 = tm()->Get(std::make_unique<ast::type::VectorType>(f32, 2));
auto p = parser("a : i32, b: f32, c: vec2<f32>"); auto p = parser("a : i32, b: f32, c: vec2<f32>");
auto& mod = p->get_module();
auto* i32 = mod.type_mgr().Get(std::make_unique<ast::type::I32Type>());
auto* f32 = mod.type_mgr().Get(std::make_unique<ast::type::F32Type>());
auto* vec2 =
mod.type_mgr().Get(std::make_unique<ast::type::VectorType>(f32, 2));
auto e = p->expect_param_list(); auto e = p->expect_param_list();
ASSERT_FALSE(p->has_error()) << p->error(); ASSERT_FALSE(p->has_error()) << p->error();
ASSERT_FALSE(e.errored); ASSERT_FALSE(e.errored);

View File

@ -190,9 +190,11 @@ TEST_F(ParserImplTest, PrimaryExpression_ParenExpr_InvalidExpr) {
} }
TEST_F(ParserImplTest, PrimaryExpression_Cast) { TEST_F(ParserImplTest, PrimaryExpression_Cast) {
auto* f32_type = tm()->Get(std::make_unique<ast::type::F32Type>());
auto p = parser("f32(1)"); auto p = parser("f32(1)");
auto& mod = p->get_module();
auto* f32 = mod.type_mgr().Get(std::make_unique<ast::type::F32Type>());
auto e = p->primary_expression(); auto e = p->primary_expression();
EXPECT_TRUE(e.matched); EXPECT_TRUE(e.matched);
EXPECT_FALSE(e.errored); EXPECT_FALSE(e.errored);
@ -202,7 +204,7 @@ TEST_F(ParserImplTest, PrimaryExpression_Cast) {
ASSERT_TRUE(e->AsConstructor()->IsTypeConstructor()); ASSERT_TRUE(e->AsConstructor()->IsTypeConstructor());
auto* c = e->AsConstructor()->AsTypeConstructor(); auto* c = e->AsConstructor()->AsTypeConstructor();
ASSERT_EQ(c->type(), f32_type); ASSERT_EQ(c->type(), f32);
ASSERT_EQ(c->values().size(), 1u); ASSERT_EQ(c->values().size(), 1u);
ASSERT_TRUE(c->values()[0]->IsConstructor()); ASSERT_TRUE(c->values()[0]->IsConstructor());
@ -210,9 +212,11 @@ TEST_F(ParserImplTest, PrimaryExpression_Cast) {
} }
TEST_F(ParserImplTest, PrimaryExpression_Bitcast) { TEST_F(ParserImplTest, PrimaryExpression_Bitcast) {
auto* f32_type = tm()->Get(std::make_unique<ast::type::F32Type>());
auto p = parser("bitcast<f32>(1)"); auto p = parser("bitcast<f32>(1)");
auto& mod = p->get_module();
auto* f32 = mod.type_mgr().Get(std::make_unique<ast::type::F32Type>());
auto e = p->primary_expression(); auto e = p->primary_expression();
EXPECT_TRUE(e.matched); EXPECT_TRUE(e.matched);
EXPECT_FALSE(e.errored); EXPECT_FALSE(e.errored);
@ -221,7 +225,7 @@ TEST_F(ParserImplTest, PrimaryExpression_Bitcast) {
ASSERT_TRUE(e->IsBitcast()); ASSERT_TRUE(e->IsBitcast());
auto* c = e->AsBitcast(); auto* c = e->AsBitcast();
ASSERT_EQ(c->type(), f32_type); ASSERT_EQ(c->type(), f32);
ASSERT_TRUE(c->expr()->IsConstructor()); ASSERT_TRUE(c->expr()->IsConstructor());
ASSERT_TRUE(c->expr()->AsConstructor()->IsScalarConstructor()); ASSERT_TRUE(c->expr()->AsConstructor()->IsScalarConstructor());

View File

@ -23,9 +23,11 @@ namespace wgsl {
namespace { namespace {
TEST_F(ParserImplTest, StructBodyDecl_Parses) { TEST_F(ParserImplTest, StructBodyDecl_Parses) {
auto* i32 = tm()->Get(std::make_unique<ast::type::I32Type>());
auto p = parser("{a : i32;}"); auto p = parser("{a : i32;}");
auto& mod = p->get_module();
auto* i32 = mod.type_mgr().Get(std::make_unique<ast::type::I32Type>());
auto m = p->expect_struct_body_decl(); auto m = p->expect_struct_body_decl();
ASSERT_FALSE(p->has_error()); ASSERT_FALSE(p->has_error());
ASSERT_FALSE(m.errored); ASSERT_FALSE(m.errored);

View File

@ -24,9 +24,11 @@ namespace wgsl {
namespace { namespace {
TEST_F(ParserImplTest, StructMember_Parses) { TEST_F(ParserImplTest, StructMember_Parses) {
auto* i32 = tm()->Get(std::make_unique<ast::type::I32Type>());
auto p = parser("a : i32;"); auto p = parser("a : i32;");
auto& mod = p->get_module();
auto* i32 = mod.type_mgr().Get(std::make_unique<ast::type::I32Type>());
auto decos = p->decoration_list(); auto decos = p->decoration_list();
EXPECT_FALSE(decos.errored); EXPECT_FALSE(decos.errored);
EXPECT_FALSE(decos.matched); EXPECT_FALSE(decos.matched);
@ -48,9 +50,11 @@ TEST_F(ParserImplTest, StructMember_Parses) {
} }
TEST_F(ParserImplTest, StructMember_ParsesWithDecoration) { TEST_F(ParserImplTest, StructMember_ParsesWithDecoration) {
auto* i32 = tm()->Get(std::make_unique<ast::type::I32Type>());
auto p = parser("[[offset(2)]] a : i32;"); auto p = parser("[[offset(2)]] a : i32;");
auto& mod = p->get_module();
auto* i32 = mod.type_mgr().Get(std::make_unique<ast::type::I32Type>());
auto decos = p->decoration_list(); auto decos = p->decoration_list();
EXPECT_FALSE(decos.errored); EXPECT_FALSE(decos.errored);
EXPECT_TRUE(decos.matched); EXPECT_TRUE(decos.matched);
@ -74,10 +78,12 @@ TEST_F(ParserImplTest, StructMember_ParsesWithDecoration) {
} }
TEST_F(ParserImplTest, StructMember_ParsesWithMultipleDecorations) { TEST_F(ParserImplTest, StructMember_ParsesWithMultipleDecorations) {
auto* i32 = tm()->Get(std::make_unique<ast::type::I32Type>());
auto p = parser(R"([[offset(2)]] auto p = parser(R"([[offset(2)]]
[[offset(4)]] a : i32;)"); [[offset(4)]] a : i32;)");
auto& mod = p->get_module();
auto* i32 = mod.type_mgr().Get(std::make_unique<ast::type::I32Type>());
auto decos = p->decoration_list(); auto decos = p->decoration_list();
EXPECT_FALSE(decos.errored); EXPECT_FALSE(decos.errored);
EXPECT_TRUE(decos.matched); EXPECT_TRUE(decos.matched);

View File

@ -39,7 +39,7 @@ fn main() -> void {
)"); )");
ASSERT_TRUE(p->Parse()) << p->error(); ASSERT_TRUE(p->Parse()) << p->error();
auto m = p->module(); auto& m = p->get_module();
ASSERT_EQ(1u, m.functions().size()); ASSERT_EQ(1u, m.functions().size());
ASSERT_EQ(1u, m.global_variables().size()); ASSERT_EQ(1u, m.global_variables().size());
} }

View File

@ -45,9 +45,6 @@ class ParserImplTest : public testing::Test {
return impl; return impl;
} }
/// @returns the type manager
ast::TypeManager* tm() { return &(ctx_.type_mgr()); }
private: private:
std::vector<std::unique_ptr<Source::File>> files_; std::vector<std::unique_ptr<Source::File>> files_;
Context ctx_; Context ctx_;
@ -71,9 +68,6 @@ class ParserImplTestWithParam : public testing::TestWithParam<T> {
return impl; return impl;
} }
/// @returns the type manager
ast::TypeManager* tm() { return &(ctx_.type_mgr()); }
private: private:
std::vector<std::unique_ptr<Source::File>> files_; std::vector<std::unique_ptr<Source::File>> files_;
Context ctx_; Context ctx_;

View File

@ -26,9 +26,11 @@ namespace wgsl {
namespace { namespace {
TEST_F(ParserImplTest, TypeDecl_ParsesType) { TEST_F(ParserImplTest, TypeDecl_ParsesType) {
auto* i32 = tm()->Get(std::make_unique<ast::type::I32Type>());
auto p = parser("type a = i32"); auto p = parser("type a = i32");
auto& mod = p->get_module();
auto* i32 = mod.type_mgr().Get(std::make_unique<ast::type::I32Type>());
auto t = p->type_alias(); auto t = p->type_alias();
EXPECT_FALSE(p->has_error()); EXPECT_FALSE(p->has_error());
EXPECT_FALSE(t.errored); EXPECT_FALSE(t.errored);

View File

@ -46,10 +46,11 @@ TEST_F(ParserImplTest, TypeDecl_Invalid) {
TEST_F(ParserImplTest, TypeDecl_Identifier) { TEST_F(ParserImplTest, TypeDecl_Identifier) {
auto p = parser("A"); auto p = parser("A");
auto* int_type = tm()->Get(std::make_unique<ast::type::I32Type>()); auto& mod = p->get_module();
// Pre-register to make sure that it's the same type.
auto* int_type = mod.type_mgr().Get(std::make_unique<ast::type::I32Type>());
auto* alias_type = auto* alias_type =
tm()->Get(std::make_unique<ast::type::AliasType>("A", int_type)); mod.type_mgr().Get(std::make_unique<ast::type::AliasType>("A", int_type));
p->register_constructed("A", alias_type); p->register_constructed("A", alias_type);
@ -79,7 +80,8 @@ TEST_F(ParserImplTest, TypeDecl_Identifier_NotFound) {
TEST_F(ParserImplTest, TypeDecl_Bool) { TEST_F(ParserImplTest, TypeDecl_Bool) {
auto p = parser("bool"); auto p = parser("bool");
auto* bool_type = tm()->Get(std::make_unique<ast::type::BoolType>()); auto& mod = p->get_module();
auto* bool_type = mod.type_mgr().Get(std::make_unique<ast::type::BoolType>());
auto t = p->type_decl(); auto t = p->type_decl();
EXPECT_TRUE(t.matched); EXPECT_TRUE(t.matched);
@ -92,7 +94,8 @@ TEST_F(ParserImplTest, TypeDecl_Bool) {
TEST_F(ParserImplTest, TypeDecl_F32) { TEST_F(ParserImplTest, TypeDecl_F32) {
auto p = parser("f32"); auto p = parser("f32");
auto* float_type = tm()->Get(std::make_unique<ast::type::F32Type>()); auto& mod = p->get_module();
auto* float_type = mod.type_mgr().Get(std::make_unique<ast::type::F32Type>());
auto t = p->type_decl(); auto t = p->type_decl();
EXPECT_TRUE(t.matched); EXPECT_TRUE(t.matched);
@ -105,7 +108,8 @@ TEST_F(ParserImplTest, TypeDecl_F32) {
TEST_F(ParserImplTest, TypeDecl_I32) { TEST_F(ParserImplTest, TypeDecl_I32) {
auto p = parser("i32"); auto p = parser("i32");
auto* int_type = tm()->Get(std::make_unique<ast::type::I32Type>()); auto& mod = p->get_module();
auto* int_type = mod.type_mgr().Get(std::make_unique<ast::type::I32Type>());
auto t = p->type_decl(); auto t = p->type_decl();
EXPECT_TRUE(t.matched); EXPECT_TRUE(t.matched);
@ -118,7 +122,8 @@ TEST_F(ParserImplTest, TypeDecl_I32) {
TEST_F(ParserImplTest, TypeDecl_U32) { TEST_F(ParserImplTest, TypeDecl_U32) {
auto p = parser("u32"); auto p = parser("u32");
auto* uint_type = tm()->Get(std::make_unique<ast::type::U32Type>()); auto& mod = p->get_module();
auto* uint_type = mod.type_mgr().Get(std::make_unique<ast::type::U32Type>());
auto t = p->type_decl(); auto t = p->type_decl();
EXPECT_TRUE(t.matched); EXPECT_TRUE(t.matched);
@ -734,7 +739,8 @@ INSTANTIATE_TEST_SUITE_P(ParserImplTest,
TEST_F(ParserImplTest, TypeDecl_Sampler) { TEST_F(ParserImplTest, TypeDecl_Sampler) {
auto p = parser("sampler"); auto p = parser("sampler");
auto* type = tm()->Get(std::make_unique<ast::type::SamplerType>( auto& mod = p->get_module();
auto* type = mod.type_mgr().Get(std::make_unique<ast::type::SamplerType>(
ast::type::SamplerKind::kSampler)); ast::type::SamplerKind::kSampler));
auto t = p->type_decl(); auto t = p->type_decl();
@ -749,9 +755,11 @@ TEST_F(ParserImplTest, TypeDecl_Sampler) {
TEST_F(ParserImplTest, TypeDecl_Texture_Old) { TEST_F(ParserImplTest, TypeDecl_Texture_Old) {
auto p = parser("texture_sampled_cube<f32>"); auto p = parser("texture_sampled_cube<f32>");
auto& mod = p->get_module();
ast::type::F32Type f32; ast::type::F32Type f32;
auto* type = tm()->Get(std::make_unique<ast::type::SampledTextureType>( auto* type =
ast::type::TextureDimension::kCube, &f32)); mod.type_mgr().Get(std::make_unique<ast::type::SampledTextureType>(
ast::type::TextureDimension::kCube, &f32));
auto t = p->type_decl(); auto t = p->type_decl();
EXPECT_TRUE(t.matched); EXPECT_TRUE(t.matched);
@ -767,8 +775,10 @@ TEST_F(ParserImplTest, TypeDecl_Texture) {
auto p = parser("texture_cube<f32>"); auto p = parser("texture_cube<f32>");
ast::type::F32Type f32; ast::type::F32Type f32;
auto* type = tm()->Get(std::make_unique<ast::type::SampledTextureType>( auto& mod = p->get_module();
ast::type::TextureDimension::kCube, &f32)); auto* type =
mod.type_mgr().Get(std::make_unique<ast::type::SampledTextureType>(
ast::type::TextureDimension::kCube, &f32));
auto t = p->type_decl(); auto t = p->type_decl();
EXPECT_TRUE(t.matched); EXPECT_TRUE(t.matched);

View File

@ -237,7 +237,7 @@ bool BoundArrayAccessorsTransform::ProcessAccessExpression(
return false; return false;
} }
} else { } else {
auto* u32 = ctx_->type_mgr().Get(std::make_unique<ast::type::U32Type>()); auto* u32 = mod_->type_mgr().Get(std::make_unique<ast::type::U32Type>());
ast::ExpressionList cast_expr; ast::ExpressionList cast_expr;
cast_expr.push_back(expr->idx_expr()); cast_expr.push_back(expr->idx_expr());

View File

@ -222,7 +222,7 @@ void VertexPullingTransform::AddVertexStorageBuffers() {
ary_decos.push_back(create<ast::StrideDecoration>(4u, Source{})); ary_decos.push_back(create<ast::StrideDecoration>(4u, Source{}));
internal_array->set_decorations(std::move(ary_decos)); internal_array->set_decorations(std::move(ary_decos));
auto* internal_array_type = ctx_->type_mgr().Get(std::move(internal_array)); auto* internal_array_type = mod_->type_mgr().Get(std::move(internal_array));
// Creating the struct type // Creating the struct type
ast::StructMemberList members; ast::StructMemberList members;
@ -236,7 +236,7 @@ void VertexPullingTransform::AddVertexStorageBuffers() {
decos.push_back(create<ast::StructBlockDecoration>(Source{})); decos.push_back(create<ast::StructBlockDecoration>(Source{}));
auto* struct_type = auto* struct_type =
ctx_->type_mgr().Get(std::make_unique<ast::type::StructType>( mod_->type_mgr().Get(std::make_unique<ast::type::StructType>(
kStructName, kStructName,
create<ast::Struct>(std::move(decos), std::move(members)))); create<ast::Struct>(std::move(decos), std::move(members))));
@ -411,21 +411,21 @@ ast::Expression* VertexPullingTransform::AccessVec(uint32_t buffer,
} }
return create<ast::TypeConstructorExpression>( return create<ast::TypeConstructorExpression>(
ctx_->type_mgr().Get( mod_->type_mgr().Get(
std::make_unique<ast::type::VectorType>(base_type, count)), std::make_unique<ast::type::VectorType>(base_type, count)),
std::move(expr_list)); std::move(expr_list));
} }
ast::type::Type* VertexPullingTransform::GetU32Type() { ast::type::Type* VertexPullingTransform::GetU32Type() {
return ctx_->type_mgr().Get(std::make_unique<ast::type::U32Type>()); return mod_->type_mgr().Get(std::make_unique<ast::type::U32Type>());
} }
ast::type::Type* VertexPullingTransform::GetI32Type() { ast::type::Type* VertexPullingTransform::GetI32Type() {
return ctx_->type_mgr().Get(std::make_unique<ast::type::I32Type>()); return mod_->type_mgr().Get(std::make_unique<ast::type::I32Type>());
} }
ast::type::Type* VertexPullingTransform::GetF32Type() { ast::type::Type* VertexPullingTransform::GetF32Type() {
return ctx_->type_mgr().Get(std::make_unique<ast::type::F32Type>()); return mod_->type_mgr().Get(std::make_unique<ast::type::F32Type>());
} }
VertexBufferLayoutDescriptor::VertexBufferLayoutDescriptor() = default; VertexBufferLayoutDescriptor::VertexBufferLayoutDescriptor() = default;

View File

@ -48,7 +48,7 @@ class VertexPullingTransformHelper {
void InitBasicModule() { void InitBasicModule() {
auto* func = create<ast::Function>( auto* func = create<ast::Function>(
"main", ast::VariableList{}, "main", ast::VariableList{},
ctx_.type_mgr().Get(std::make_unique<ast::type::VoidType>()), mod_->type_mgr().Get(std::make_unique<ast::type::VoidType>()),
create<ast::BlockStatement>()); create<ast::BlockStatement>());
func->add_decoration( func->add_decoration(
create<ast::StageDecoration>(ast::PipelineStage::kVertex, Source{})); create<ast::StageDecoration>(ast::PipelineStage::kVertex, Source{}));
@ -81,7 +81,6 @@ class VertexPullingTransformHelper {
mod_->AddGlobalVariable(var); mod_->AddGlobalVariable(var);
} }
Context* ctx() { return &ctx_; }
ast::Module* mod() { return mod_.get(); } ast::Module* mod() { return mod_.get(); }
Manager* manager() { return manager_.get(); } Manager* manager() { return manager_.get(); }
VertexPullingTransform* transform() { return transform_; } VertexPullingTransform* transform() { return transform_; }
@ -128,7 +127,7 @@ TEST_F(VertexPullingTransformTest, Error_InvalidEntryPoint) {
TEST_F(VertexPullingTransformTest, Error_EntryPointWrongStage) { TEST_F(VertexPullingTransformTest, Error_EntryPointWrongStage) {
auto* func = create<ast::Function>( auto* func = create<ast::Function>(
"main", ast::VariableList{}, "main", ast::VariableList{},
ctx()->type_mgr().Get(std::make_unique<ast::type::VoidType>()), mod()->type_mgr().Get(std::make_unique<ast::type::VoidType>()),
create<ast::BlockStatement>()); create<ast::BlockStatement>());
func->add_decoration( func->add_decoration(
create<ast::StageDecoration>(ast::PipelineStage::kFragment, Source{})); create<ast::StageDecoration>(ast::PipelineStage::kFragment, Source{}));

View File

@ -83,7 +83,7 @@ void TypeDeterminer::set_referenced_from_function_if_needed(
} }
bool TypeDeterminer::Determine() { bool TypeDeterminer::Determine() {
for (auto& iter : ctx_.type_mgr().types()) { for (auto& iter : mod_->type_mgr().types()) {
auto& type = iter.second; auto& type = iter.second;
if (!type->IsTexture() || !type->AsTexture()->IsStorage()) { if (!type->IsTexture() || !type->AsTexture()->IsStorage()) {
continue; continue;
@ -339,7 +339,7 @@ bool TypeDeterminer::DetermineArrayAccessor(
ret = parent_type->AsVector()->type(); ret = parent_type->AsVector()->type();
} else if (parent_type->IsMatrix()) { } else if (parent_type->IsMatrix()) {
auto* m = parent_type->AsMatrix(); auto* m = parent_type->AsMatrix();
ret = ctx_.type_mgr().Get( ret = mod_->type_mgr().Get(
std::make_unique<ast::type::VectorType>(m->type(), m->rows())); std::make_unique<ast::type::VectorType>(m->type(), m->rows()));
} else { } else {
set_error(expr->source(), "invalid parent type (" + set_error(expr->source(), "invalid parent type (" +
@ -350,14 +350,14 @@ bool TypeDeterminer::DetermineArrayAccessor(
// If we're extracting from a pointer, we return a pointer. // If we're extracting from a pointer, we return a pointer.
if (res->IsPointer()) { if (res->IsPointer()) {
ret = ctx_.type_mgr().Get(std::make_unique<ast::type::PointerType>( ret = mod_->type_mgr().Get(std::make_unique<ast::type::PointerType>(
ret, res->AsPointer()->storage_class())); ret, res->AsPointer()->storage_class()));
} else if (parent_type->IsArray() && } else if (parent_type->IsArray() &&
!parent_type->AsArray()->type()->is_scalar()) { !parent_type->AsArray()->type()->is_scalar()) {
// If we extract a non-scalar from an array then we also get a pointer. We // If we extract a non-scalar from an array then we also get a pointer. We
// will generate a Function storage class variable to store this // will generate a Function storage class variable to store this
// into. // into.
ret = ctx_.type_mgr().Get(std::make_unique<ast::type::PointerType>( ret = mod_->type_mgr().Get(std::make_unique<ast::type::PointerType>(
ret, ast::StorageClass::kFunction)); ret, ast::StorageClass::kFunction));
} }
expr->set_result_type(ret); expr->set_result_type(ret);
@ -523,12 +523,12 @@ bool TypeDeterminer::DetermineIntrinsic(ast::IdentifierExpression* ident,
if (ident->intrinsic() == ast::Intrinsic::kAny || if (ident->intrinsic() == ast::Intrinsic::kAny ||
ident->intrinsic() == ast::Intrinsic::kAll) { ident->intrinsic() == ast::Intrinsic::kAll) {
expr->func()->set_result_type( expr->func()->set_result_type(
ctx_.type_mgr().Get(std::make_unique<ast::type::BoolType>())); mod_->type_mgr().Get(std::make_unique<ast::type::BoolType>()));
return true; return true;
} }
if (ident->intrinsic() == ast::Intrinsic::kArrayLength) { if (ident->intrinsic() == ast::Intrinsic::kArrayLength) {
expr->func()->set_result_type( expr->func()->set_result_type(
ctx_.type_mgr().Get(std::make_unique<ast::type::U32Type>())); mod_->type_mgr().Get(std::make_unique<ast::type::U32Type>()));
return true; return true;
} }
if (ast::intrinsic::IsFloatClassificationIntrinsic(ident->intrinsic())) { if (ast::intrinsic::IsFloatClassificationIntrinsic(ident->intrinsic())) {
@ -539,12 +539,12 @@ bool TypeDeterminer::DetermineIntrinsic(ast::IdentifierExpression* ident,
} }
auto* bool_type = auto* bool_type =
ctx_.type_mgr().Get(std::make_unique<ast::type::BoolType>()); mod_->type_mgr().Get(std::make_unique<ast::type::BoolType>());
auto* param_type = expr->params()[0]->result_type()->UnwrapPtrIfNeeded(); auto* param_type = expr->params()[0]->result_type()->UnwrapPtrIfNeeded();
if (param_type->IsVector()) { if (param_type->IsVector()) {
expr->func()->set_result_type( expr->func()->set_result_type(
ctx_.type_mgr().Get(std::make_unique<ast::type::VectorType>( mod_->type_mgr().Get(std::make_unique<ast::type::VectorType>(
bool_type, param_type->AsVector()->size()))); bool_type, param_type->AsVector()->size())));
} else { } else {
expr->func()->set_result_type(bool_type); expr->func()->set_result_type(bool_type);
@ -667,7 +667,7 @@ bool TypeDeterminer::DetermineIntrinsic(ast::IdentifierExpression* ident,
if (texture->IsDepth()) { if (texture->IsDepth()) {
expr->func()->set_result_type( expr->func()->set_result_type(
ctx_.type_mgr().Get(std::make_unique<ast::type::F32Type>())); mod_->type_mgr().Get(std::make_unique<ast::type::F32Type>()));
return true; return true;
} }
@ -689,12 +689,12 @@ bool TypeDeterminer::DetermineIntrinsic(ast::IdentifierExpression* ident,
return false; return false;
} }
expr->func()->set_result_type( expr->func()->set_result_type(
ctx_.type_mgr().Get(std::make_unique<ast::type::VectorType>(type, 4))); mod_->type_mgr().Get(std::make_unique<ast::type::VectorType>(type, 4)));
return true; return true;
} }
if (ident->intrinsic() == ast::Intrinsic::kDot) { if (ident->intrinsic() == ast::Intrinsic::kDot) {
expr->func()->set_result_type( expr->func()->set_result_type(
ctx_.type_mgr().Get(std::make_unique<ast::type::F32Type>())); mod_->type_mgr().Get(std::make_unique<ast::type::F32Type>()));
return true; return true;
} }
if (ident->intrinsic() == ast::Intrinsic::kOuterProduct) { if (ident->intrinsic() == ast::Intrinsic::kOuterProduct) {
@ -712,8 +712,8 @@ bool TypeDeterminer::DetermineIntrinsic(ast::IdentifierExpression* ident,
} }
expr->func()->set_result_type( expr->func()->set_result_type(
ctx_.type_mgr().Get(std::make_unique<ast::type::MatrixType>( mod_->type_mgr().Get(std::make_unique<ast::type::MatrixType>(
ctx_.type_mgr().Get(std::make_unique<ast::type::F32Type>()), mod_->type_mgr().Get(std::make_unique<ast::type::F32Type>()),
param0_type->AsVector()->size(), param1_type->AsVector()->size()))); param0_type->AsVector()->size(), param1_type->AsVector()->size())));
return true; return true;
} }
@ -862,7 +862,7 @@ bool TypeDeterminer::DetermineIdentifier(ast::IdentifierExpression* expr) {
expr->set_result_type(var->type()); expr->set_result_type(var->type());
} else { } else {
expr->set_result_type( expr->set_result_type(
ctx_.type_mgr().Get(std::make_unique<ast::type::PointerType>( mod_->type_mgr().Get(std::make_unique<ast::type::PointerType>(
var->type(), var->storage_class()))); var->type(), var->storage_class())));
} }
@ -1055,7 +1055,7 @@ bool TypeDeterminer::DetermineMemberAccessor(
// If we're extracting from a pointer, we return a pointer. // If we're extracting from a pointer, we return a pointer.
if (res->IsPointer()) { if (res->IsPointer()) {
ret = ctx_.type_mgr().Get(std::make_unique<ast::type::PointerType>( ret = mod_->type_mgr().Get(std::make_unique<ast::type::PointerType>(
ret, res->AsPointer()->storage_class())); ret, res->AsPointer()->storage_class()));
} }
} else if (data_type->IsVector()) { } else if (data_type->IsVector()) {
@ -1067,14 +1067,14 @@ bool TypeDeterminer::DetermineMemberAccessor(
ret = vec->type(); ret = vec->type();
// If we're extracting from a pointer, we return a pointer. // If we're extracting from a pointer, we return a pointer.
if (res->IsPointer()) { if (res->IsPointer()) {
ret = ctx_.type_mgr().Get(std::make_unique<ast::type::PointerType>( ret = mod_->type_mgr().Get(std::make_unique<ast::type::PointerType>(
ret, res->AsPointer()->storage_class())); ret, res->AsPointer()->storage_class()));
} }
} else { } else {
// The vector will have a number of components equal to the length of the // The vector will have a number of components equal to the length of the
// swizzle. This assumes the validator will check that the swizzle // swizzle. This assumes the validator will check that the swizzle
// is correct. // is correct.
ret = ctx_.type_mgr().Get( ret = mod_->type_mgr().Get(
std::make_unique<ast::type::VectorType>(vec->type(), size)); std::make_unique<ast::type::VectorType>(vec->type(), size));
} }
} else { } else {
@ -1107,11 +1107,11 @@ bool TypeDeterminer::DetermineBinary(ast::BinaryExpression* expr) {
expr->IsNotEqual() || expr->IsLessThan() || expr->IsGreaterThan() || expr->IsNotEqual() || expr->IsLessThan() || expr->IsGreaterThan() ||
expr->IsLessThanEqual() || expr->IsGreaterThanEqual()) { expr->IsLessThanEqual() || expr->IsGreaterThanEqual()) {
auto* bool_type = auto* bool_type =
ctx_.type_mgr().Get(std::make_unique<ast::type::BoolType>()); mod_->type_mgr().Get(std::make_unique<ast::type::BoolType>());
auto* param_type = expr->lhs()->result_type()->UnwrapPtrIfNeeded(); auto* param_type = expr->lhs()->result_type()->UnwrapPtrIfNeeded();
if (param_type->IsVector()) { if (param_type->IsVector()) {
expr->set_result_type( expr->set_result_type(
ctx_.type_mgr().Get(std::make_unique<ast::type::VectorType>( mod_->type_mgr().Get(std::make_unique<ast::type::VectorType>(
bool_type, param_type->AsVector()->size()))); bool_type, param_type->AsVector()->size())));
} else { } else {
expr->set_result_type(bool_type); expr->set_result_type(bool_type);
@ -1126,18 +1126,18 @@ bool TypeDeterminer::DetermineBinary(ast::BinaryExpression* expr) {
// checks having been done. // checks having been done.
if (lhs_type->IsMatrix() && rhs_type->IsMatrix()) { if (lhs_type->IsMatrix() && rhs_type->IsMatrix()) {
expr->set_result_type( expr->set_result_type(
ctx_.type_mgr().Get(std::make_unique<ast::type::MatrixType>( mod_->type_mgr().Get(std::make_unique<ast::type::MatrixType>(
lhs_type->AsMatrix()->type(), lhs_type->AsMatrix()->rows(), lhs_type->AsMatrix()->type(), lhs_type->AsMatrix()->rows(),
rhs_type->AsMatrix()->columns()))); rhs_type->AsMatrix()->columns())));
} else if (lhs_type->IsMatrix() && rhs_type->IsVector()) { } else if (lhs_type->IsMatrix() && rhs_type->IsVector()) {
auto* mat = lhs_type->AsMatrix(); auto* mat = lhs_type->AsMatrix();
expr->set_result_type(ctx_.type_mgr().Get( expr->set_result_type(mod_->type_mgr().Get(
std::make_unique<ast::type::VectorType>(mat->type(), mat->rows()))); std::make_unique<ast::type::VectorType>(mat->type(), mat->rows())));
} else if (lhs_type->IsVector() && rhs_type->IsMatrix()) { } else if (lhs_type->IsVector() && rhs_type->IsMatrix()) {
auto* mat = rhs_type->AsMatrix(); auto* mat = rhs_type->AsMatrix();
expr->set_result_type( expr->set_result_type(
ctx_.type_mgr().Get(std::make_unique<ast::type::VectorType>( mod_->type_mgr().Get(std::make_unique<ast::type::VectorType>(
mat->type(), mat->columns()))); mat->type(), mat->columns())));
} else if (lhs_type->IsMatrix()) { } else if (lhs_type->IsMatrix()) {
// matrix * scalar // matrix * scalar
@ -1198,7 +1198,7 @@ bool TypeDeterminer::DetermineStorageTextureSubtype(
case ast::type::ImageFormat::kRgba16Uint: case ast::type::ImageFormat::kRgba16Uint:
case ast::type::ImageFormat::kRgba32Uint: { case ast::type::ImageFormat::kRgba32Uint: {
tex->set_type( tex->set_type(
ctx_.type_mgr().Get(std::make_unique<ast::type::U32Type>())); mod_->type_mgr().Get(std::make_unique<ast::type::U32Type>()));
return true; return true;
} }
@ -1215,7 +1215,7 @@ bool TypeDeterminer::DetermineStorageTextureSubtype(
case ast::type::ImageFormat::kRgba16Sint: case ast::type::ImageFormat::kRgba16Sint:
case ast::type::ImageFormat::kRgba32Sint: { case ast::type::ImageFormat::kRgba32Sint: {
tex->set_type( tex->set_type(
ctx_.type_mgr().Get(std::make_unique<ast::type::I32Type>())); mod_->type_mgr().Get(std::make_unique<ast::type::I32Type>()));
return true; return true;
} }
@ -1227,7 +1227,7 @@ bool TypeDeterminer::DetermineStorageTextureSubtype(
case ast::type::ImageFormat::kRgba16Float: case ast::type::ImageFormat::kRgba16Float:
case ast::type::ImageFormat::kRgba32Float: { case ast::type::ImageFormat::kRgba32Float: {
tex->set_type( tex->set_type(
ctx_.type_mgr().Get(std::make_unique<ast::type::F32Type>())); mod_->type_mgr().Get(std::make_unique<ast::type::F32Type>()));
return true; return true;
} }

View File

@ -1787,7 +1787,7 @@ TEST_P(Intrinsic_StorageTextureOperation, TextureLoadRo) {
auto coords_type = get_coords_type(dim, &i32); auto coords_type = get_coords_type(dim, &i32);
ast::type::Type* texture_type = ast::type::Type* texture_type =
ctx->type_mgr().Get(std::make_unique<ast::type::StorageTextureType>( mod->type_mgr().Get(std::make_unique<ast::type::StorageTextureType>(
dim, ast::AccessControl::kReadOnly, format)); dim, ast::AccessControl::kReadOnly, format));
ast::ExpressionList call_params; ast::ExpressionList call_params;
@ -4549,13 +4549,13 @@ TEST_P(TypeDeterminerTextureIntrinsicTest, Call) {
switch (param.texture_kind) { switch (param.texture_kind) {
case ast::intrinsic::test::TextureKind::kRegular: case ast::intrinsic::test::TextureKind::kRegular:
Var("texture", ast::StorageClass::kNone, Var("texture", ast::StorageClass::kNone,
ctx->type_mgr().Get<ast::type::SampledTextureType>( mod->type_mgr().Get<ast::type::SampledTextureType>(
param.texture_dimension, datatype)); param.texture_dimension, datatype));
break; break;
case ast::intrinsic::test::TextureKind::kDepth: case ast::intrinsic::test::TextureKind::kDepth:
Var("texture", ast::StorageClass::kNone, Var("texture", ast::StorageClass::kNone,
ctx->type_mgr().Get<ast::type::DepthTextureType>( mod->type_mgr().Get<ast::type::DepthTextureType>(
param.texture_dimension)); param.texture_dimension));
break; break;
} }

View File

@ -183,13 +183,13 @@ TEST_P(HlslGeneratorIntrinsicTextureTest, Call) {
switch (param.texture_kind) { switch (param.texture_kind) {
case ast::intrinsic::test::TextureKind::kRegular: case ast::intrinsic::test::TextureKind::kRegular:
Var("texture", ast::StorageClass::kNone, Var("texture", ast::StorageClass::kNone,
ctx->type_mgr().Get<ast::type::SampledTextureType>( mod->type_mgr().Get<ast::type::SampledTextureType>(
param.texture_dimension, datatype)); param.texture_dimension, datatype));
break; break;
case ast::intrinsic::test::TextureKind::kDepth: case ast::intrinsic::test::TextureKind::kDepth:
Var("texture", ast::StorageClass::kNone, Var("texture", ast::StorageClass::kNone,
ctx->type_mgr().Get<ast::type::DepthTextureType>( mod->type_mgr().Get<ast::type::DepthTextureType>(
param.texture_dimension)); param.texture_dimension));
break; break;
} }

View File

@ -1618,13 +1618,13 @@ TEST_P(IntrinsicTextureTest, Call) {
switch (param.texture_kind) { switch (param.texture_kind) {
case ast::intrinsic::test::TextureKind::kRegular: case ast::intrinsic::test::TextureKind::kRegular:
tex = Var("texture", ast::StorageClass::kNone, tex = Var("texture", ast::StorageClass::kNone,
ctx->type_mgr().Get<ast::type::SampledTextureType>( mod->type_mgr().Get<ast::type::SampledTextureType>(
param.texture_dimension, datatype)); param.texture_dimension, datatype));
break; break;
case ast::intrinsic::test::TextureKind::kDepth: case ast::intrinsic::test::TextureKind::kDepth:
tex = Var("texture", ast::StorageClass::kNone, tex = Var("texture", ast::StorageClass::kNone,
ctx->type_mgr().Get<ast::type::DepthTextureType>( mod->type_mgr().Get<ast::type::DepthTextureType>(
param.texture_dimension)); param.texture_dimension));
break; break;
} }