Post migration to castable cleanup

Change-Id: I5c47b1736bd850548cb1c9c7a6f69242d8626173
Reviewed-on: https://dawn-review.googlesource.com/c/tint/+/34460
Commit-Queue: Ben Clayton <bclayton@google.com>
Reviewed-by: dan sinclair <dsinclair@chromium.org>
This commit is contained in:
Ben Clayton 2020-12-01 21:07:27 +00:00 committed by Commit Bot service account
parent 782f6a5e3e
commit 1b6a8ce165
19 changed files with 453 additions and 545 deletions

View File

@ -104,25 +104,26 @@ Function::referenced_uniform_variables() const {
std::vector<std::pair<Variable*, Function::BindingInfo>> ret; std::vector<std::pair<Variable*, Function::BindingInfo>> ret;
for (auto* var : referenced_module_variables()) { for (auto* var : referenced_module_variables()) {
if (!var->Is<DecoratedVariable>() || if (var->storage_class() != StorageClass::kUniform) {
var->storage_class() != StorageClass::kUniform) {
continue; continue;
} }
BindingDecoration* binding = nullptr; if (auto* decorated = var->As<DecoratedVariable>()) {
SetDecoration* set = nullptr; BindingDecoration* binding = nullptr;
for (auto* deco : var->As<DecoratedVariable>()->decorations()) { SetDecoration* set = nullptr;
if (auto* b = deco->As<BindingDecoration>()) { for (auto* deco : decorated->decorations()) {
binding = b; if (auto* b = deco->As<BindingDecoration>()) {
} else if (auto* s = deco->As<SetDecoration>()) { binding = b;
set = s; } else if (auto* s = deco->As<SetDecoration>()) {
set = s;
}
}
if (binding == nullptr || set == nullptr) {
continue;
} }
}
if (binding == nullptr || set == nullptr) {
continue;
}
ret.push_back({var, BindingInfo{binding, set}}); ret.push_back({var, BindingInfo{binding, set}});
}
} }
return ret; return ret;
} }
@ -132,25 +133,26 @@ Function::referenced_storagebuffer_variables() const {
std::vector<std::pair<Variable*, Function::BindingInfo>> ret; std::vector<std::pair<Variable*, Function::BindingInfo>> ret;
for (auto* var : referenced_module_variables()) { for (auto* var : referenced_module_variables()) {
if (!var->Is<DecoratedVariable>() || if (var->storage_class() != StorageClass::kStorageBuffer) {
var->storage_class() != StorageClass::kStorageBuffer) {
continue; continue;
} }
BindingDecoration* binding = nullptr; if (auto* decorated = var->As<DecoratedVariable>()) {
SetDecoration* set = nullptr; BindingDecoration* binding = nullptr;
for (auto* deco : var->As<DecoratedVariable>()->decorations()) { SetDecoration* set = nullptr;
if (auto* b = deco->As<BindingDecoration>()) { for (auto* deco : decorated->decorations()) {
binding = b; if (auto* b = deco->As<BindingDecoration>()) {
} else if (auto* s = deco->As<SetDecoration>()) { binding = b;
set = s; } else if (auto* s = deco->As<SetDecoration>()) {
set = s;
}
}
if (binding == nullptr || set == nullptr) {
continue;
} }
}
if (binding == nullptr || set == nullptr) {
continue;
}
ret.push_back({var, BindingInfo{binding, set}}); ret.push_back({var, BindingInfo{binding, set}});
}
} }
return ret; return ret;
} }
@ -160,13 +162,12 @@ Function::referenced_builtin_variables() const {
std::vector<std::pair<Variable*, BuiltinDecoration*>> ret; std::vector<std::pair<Variable*, BuiltinDecoration*>> ret;
for (auto* var : referenced_module_variables()) { for (auto* var : referenced_module_variables()) {
if (!var->Is<DecoratedVariable>()) { if (auto* decorated = var->As<DecoratedVariable>()) {
continue; for (auto* deco : decorated->decorations()) {
} if (auto* builtin = deco->As<BuiltinDecoration>()) {
for (auto* deco : var->As<DecoratedVariable>()->decorations()) { ret.push_back({var, builtin});
if (auto* builtin = deco->As<BuiltinDecoration>()) { break;
ret.push_back({var, builtin}); }
break;
} }
} }
} }
@ -292,25 +293,28 @@ Function::ReferencedSamplerVariablesImpl(type::SamplerKind kind) const {
for (auto* var : referenced_module_variables()) { for (auto* var : referenced_module_variables()) {
auto* unwrapped_type = var->type()->UnwrapIfNeeded(); auto* unwrapped_type = var->type()->UnwrapIfNeeded();
if (!var->Is<DecoratedVariable>() || !unwrapped_type->Is<type::Sampler>() || auto* sampler = unwrapped_type->As<type::Sampler>();
unwrapped_type->As<type::Sampler>()->kind() != kind) { if (sampler == nullptr || sampler->kind() != kind) {
continue; continue;
} }
BindingDecoration* binding = nullptr; if (auto* decorated = var->As<DecoratedVariable>()) {
SetDecoration* set = nullptr; BindingDecoration* binding = nullptr;
for (auto* deco : var->As<DecoratedVariable>()->decorations()) { SetDecoration* set = nullptr;
if (auto* b = deco->As<BindingDecoration>()) { for (auto* deco : decorated->decorations()) {
binding = b; if (auto* b = deco->As<BindingDecoration>()) {
} else if (auto* s = deco->As<SetDecoration>()) { binding = b;
set = s; }
if (auto* s = deco->As<SetDecoration>()) {
set = s;
}
}
if (binding == nullptr || set == nullptr) {
continue;
} }
}
if (binding == nullptr || set == nullptr) {
continue;
}
ret.push_back({var, BindingInfo{binding, set}}); ret.push_back({var, BindingInfo{binding, set}});
}
} }
return ret; return ret;
} }
@ -321,29 +325,34 @@ Function::ReferencedSampledTextureVariablesImpl(bool multisampled) const {
for (auto* var : referenced_module_variables()) { for (auto* var : referenced_module_variables()) {
auto* unwrapped_type = var->type()->UnwrapIfNeeded(); auto* unwrapped_type = var->type()->UnwrapIfNeeded();
if (!var->Is<DecoratedVariable>() || !unwrapped_type->Is<type::Texture>()) { auto* texture = unwrapped_type->As<type::Texture>();
if (texture == nullptr) {
continue; continue;
} }
if ((multisampled && !unwrapped_type->Is<type::MultisampledTexture>()) || auto is_multisampled = texture->Is<type::MultisampledTexture>();
(!multisampled && !unwrapped_type->Is<type::SampledTexture>())) { auto is_sampled = texture->Is<type::SampledTexture>();
if ((multisampled && !is_multisampled) || (!multisampled && !is_sampled)) {
continue; continue;
} }
BindingDecoration* binding = nullptr; if (auto* decorated = var->As<DecoratedVariable>()) {
SetDecoration* set = nullptr; BindingDecoration* binding = nullptr;
for (auto* deco : var->As<DecoratedVariable>()->decorations()) { SetDecoration* set = nullptr;
if (auto* b = deco->As<BindingDecoration>()) { for (auto* deco : decorated->decorations()) {
binding = b; if (auto* b = deco->As<BindingDecoration>()) {
} else if (auto* s = deco->As<SetDecoration>()) { binding = b;
set = s; } else if (auto* s = deco->As<SetDecoration>()) {
set = s;
}
}
if (binding == nullptr || set == nullptr) {
continue;
} }
}
if (binding == nullptr || set == nullptr) {
continue;
}
ret.push_back({var, BindingInfo{binding, set}}); ret.push_back({var, BindingInfo{binding, set}});
}
} }
return ret; return ret;

View File

@ -74,17 +74,16 @@ bool Module::IsValid() const {
if (ty == nullptr) { if (ty == nullptr) {
return false; return false;
} }
if (ty->Is<type::Alias>()) { if (auto* alias = ty->As<type::Alias>()) {
auto* alias = ty->As<type::Alias>();
if (alias->type() == nullptr) { if (alias->type() == nullptr) {
return false; return false;
} }
if (alias->type()->Is<type::Struct>() && if (auto* str = alias->type()->As<type::Struct>()) {
alias->type()->As<type::Struct>()->name().empty()) { if (str->name().empty()) {
return false; return false;
}
} }
} else if (ty->Is<type::Struct>()) { } else if (auto* str = ty->As<type::Struct>()) {
auto* str = ty->As<type::Struct>();
if (str->name().empty()) { if (str->name().empty()) {
return false; return false;
} }
@ -109,14 +108,12 @@ std::string Module::to_str() const {
for (size_t i = 0; i < indent; ++i) { for (size_t i = 0; i < indent; ++i) {
out << " "; out << " ";
} }
if (ty->Is<type::Alias>()) { if (auto* alias = ty->As<type::Alias>()) {
auto* alias = ty->As<type::Alias>();
out << alias->name() << " -> " << alias->type()->type_name() << std::endl; out << alias->name() << " -> " << alias->type()->type_name() << std::endl;
if (alias->type()->Is<type::Struct>()) { if (auto* str = alias->type()->As<type::Struct>()) {
alias->type()->As<type::Struct>()->impl()->to_str(out, indent); str->impl()->to_str(out, indent);
} }
} else if (ty->Is<type::Struct>()) { } else if (auto* str = ty->As<type::Struct>()) {
auto* str = ty->As<type::Struct>();
out << str->name() << " "; out << str->name() << " ";
str->impl()->to_str(out, indent); str->impl()->to_str(out, indent);
} }

View File

@ -102,7 +102,7 @@ TEST_F(StorageTextureTest, F32) {
ASSERT_TRUE(td.Determine()) << td.error(); ASSERT_TRUE(td.Determine()) << td.error();
ASSERT_TRUE(s->Is<Texture>()); ASSERT_TRUE(s->Is<Texture>());
ASSERT_TRUE(s->Is<StorageTexture>()); ASSERT_TRUE(s->Is<StorageTexture>());
EXPECT_TRUE(s->As<Texture>()->As<StorageTexture>()->type()->Is<F32>()); EXPECT_TRUE(s->As<StorageTexture>()->type()->Is<F32>());
} }
TEST_F(StorageTextureTest, U32) { TEST_F(StorageTextureTest, U32) {
@ -130,7 +130,7 @@ TEST_F(StorageTextureTest, I32) {
ASSERT_TRUE(td.Determine()) << td.error(); ASSERT_TRUE(td.Determine()) << td.error();
ASSERT_TRUE(s->Is<Texture>()); ASSERT_TRUE(s->Is<Texture>());
ASSERT_TRUE(s->Is<StorageTexture>()); ASSERT_TRUE(s->Is<StorageTexture>());
EXPECT_TRUE(s->As<Texture>()->As<StorageTexture>()->type()->Is<I32>()); EXPECT_TRUE(s->As<StorageTexture>()->type()->Is<I32>());
} }
TEST_F(StorageTextureTest, MinBufferBindingSize) { TEST_F(StorageTextureTest, MinBufferBindingSize) {

View File

@ -42,8 +42,8 @@ Type::Type(Type&&) = default;
Type::~Type() = default; Type::~Type() = default;
Type* Type::UnwrapPtrIfNeeded() { Type* Type::UnwrapPtrIfNeeded() {
if (Is<Pointer>()) { if (auto* ptr = As<type::Pointer>()) {
return As<Pointer>()->type(); return ptr->type();
} }
return this; return this;
} }
@ -51,10 +51,10 @@ Type* Type::UnwrapPtrIfNeeded() {
Type* Type::UnwrapIfNeeded() { Type* Type::UnwrapIfNeeded() {
auto* where = this; auto* where = this;
while (true) { while (true) {
if (where->Is<Alias>()) { if (auto* alias = where->As<type::Alias>()) {
where = where->As<Alias>()->type(); where = alias->type();
} else if (where->Is<AccessControl>()) { } else if (auto* access = where->As<type::AccessControl>()) {
where = where->As<AccessControl>()->type(); where = access->type();
} else { } else {
break; break;
} }

View File

@ -124,44 +124,45 @@ std::map<uint32_t, Scalar> Inspector::GetConstantIDs() {
} }
auto* expression = var->constructor(); auto* expression = var->constructor();
if (!expression->Is<ast::ConstructorExpression>()) {
// This is invalid WGSL, but handling gracefully.
result[constant_id] = Scalar();
continue;
}
auto* constructor = expression->As<ast::ConstructorExpression>(); auto* constructor = expression->As<ast::ConstructorExpression>();
if (!constructor->Is<ast::ScalarConstructorExpression>()) { if (constructor == nullptr) {
// This is invalid WGSL, but handling gracefully. // This is invalid WGSL, but handling gracefully.
result[constant_id] = Scalar(); result[constant_id] = Scalar();
continue; continue;
} }
auto* literal = auto* scalar_constructor =
constructor->As<ast::ScalarConstructorExpression>()->literal(); constructor->As<ast::ScalarConstructorExpression>();
if (scalar_constructor == nullptr) {
// This is invalid WGSL, but handling gracefully.
result[constant_id] = Scalar();
continue;
}
auto* literal = scalar_constructor->literal();
if (!literal) { if (!literal) {
// This is invalid WGSL, but handling gracefully. // This is invalid WGSL, but handling gracefully.
result[constant_id] = Scalar(); result[constant_id] = Scalar();
continue; continue;
} }
if (literal->Is<ast::BoolLiteral>()) { if (auto* l = literal->As<ast::BoolLiteral>()) {
result[constant_id] = Scalar(literal->As<ast::BoolLiteral>()->IsTrue()); result[constant_id] = Scalar(l->IsTrue());
continue; continue;
} }
if (literal->Is<ast::UintLiteral>()) { if (auto* l = literal->As<ast::UintLiteral>()) {
result[constant_id] = Scalar(literal->As<ast::UintLiteral>()->value()); result[constant_id] = Scalar(l->value());
continue; continue;
} }
if (literal->Is<ast::SintLiteral>()) { if (auto* l = literal->As<ast::SintLiteral>()) {
result[constant_id] = Scalar(literal->As<ast::SintLiteral>()->value()); result[constant_id] = Scalar(l->value());
continue; continue;
} }
if (literal->Is<ast::FloatLiteral>()) { if (auto* l = literal->As<ast::FloatLiteral>()) {
result[constant_id] = Scalar(literal->As<ast::FloatLiteral>()->value()); result[constant_id] = Scalar(l->value());
continue; continue;
} }
@ -190,11 +191,12 @@ std::vector<ResourceBinding> Inspector::GetUniformBufferResourceBindings(
} }
auto* unwrapped_type = var->type()->UnwrapIfNeeded(); auto* unwrapped_type = var->type()->UnwrapIfNeeded();
if (!unwrapped_type->Is<ast::type::Struct>()) { auto* str = unwrapped_type->As<ast::type::Struct>();
if (str == nullptr) {
continue; continue;
} }
if (!unwrapped_type->As<ast::type::Struct>()->IsBlockDecorated()) { if (!str->IsBlockDecorated()) {
continue; continue;
} }
@ -307,11 +309,12 @@ std::vector<ResourceBinding> Inspector::GetStorageBufferResourceBindingsImpl(
ast::Variable* var = nullptr; ast::Variable* var = nullptr;
ast::Function::BindingInfo binding_info; ast::Function::BindingInfo binding_info;
std::tie(var, binding_info) = rsv; std::tie(var, binding_info) = rsv;
if (!var->type()->Is<ast::type::AccessControl>()) {
auto* ac_type = var->type()->As<ast::type::AccessControl>();
if (ac_type == nullptr) {
continue; continue;
} }
auto* ac_type = var->type()->As<ast::type::AccessControl>();
if (read_only != ac_type->IsReadOnly()) { if (read_only != ac_type->IsReadOnly()) {
continue; continue;
} }
@ -392,12 +395,12 @@ std::vector<ResourceBinding> Inspector::GetSampledTextureResourceBindingsImpl(
->UnwrapIfNeeded(); ->UnwrapIfNeeded();
} }
if (base_type->Is<ast::type::Array>()) { if (auto* at = base_type->As<ast::type::Array>()) {
base_type = base_type->As<ast::type::Array>()->type(); base_type = at->type();
} else if (base_type->Is<ast::type::Matrix>()) { } else if (auto* mt = base_type->As<ast::type::Matrix>()) {
base_type = base_type->As<ast::type::Matrix>()->type(); base_type = mt->type();
} else if (base_type->Is<ast::type::Vector>()) { } else if (auto* vt = base_type->As<ast::type::Vector>()) {
base_type = base_type->As<ast::type::Vector>()->type(); base_type = vt->type();
} }
if (base_type->Is<ast::type::F32>()) { if (base_type->Is<ast::type::F32>()) {

View File

@ -2267,11 +2267,11 @@ bool FunctionEmitter::EmitContinuingStart(const Construct* construct) {
// A continue construct has the same depth as its associated loop // A continue construct has the same depth as its associated loop
// construct. Start a continue construct. // construct. Start a continue construct.
auto* loop_candidate = LastStatement(); auto* loop_candidate = LastStatement();
if (!loop_candidate->Is<ast::LoopStatement>()) { auto* loop = loop_candidate->As<ast::LoopStatement>();
if (loop == nullptr) {
return Fail() << "internal error: starting continue construct, " return Fail() << "internal error: starting continue construct, "
"expected loop on top of stack"; "expected loop on top of stack";
} }
auto* loop = loop_candidate->As<ast::LoopStatement>();
PushNewStatementBlock( PushNewStatementBlock(
construct, construct->end_id, construct, construct->end_id,
[loop](StatementBlock* s) { loop->set_continuing(s->statements_); }); [loop](StatementBlock* s) { loop->set_continuing(s->statements_); });
@ -3268,10 +3268,10 @@ bool FunctionEmitter::RegisterLocallyDefinedValues() {
const auto* type = type_mgr_->GetType(inst.type_id()); const auto* type = type_mgr_->GetType(inst.type_id());
if (type) { if (type) {
if (type->AsPointer()) { if (type->AsPointer()) {
const auto* ast_type = parser_impl_.ConvertType(inst.type_id()); if (const auto* ast_type = parser_impl_.ConvertType(inst.type_id())) {
if (ast_type && ast_type->As<ast::type::Pointer>()) { if (auto* ptr = ast_type->As<ast::type::Pointer>()) {
info->storage_class = info->storage_class = ptr->storage_class();
ast_type->As<ast::type::Pointer>()->storage_class(); }
} }
switch (inst.opcode()) { switch (inst.opcode()) {
case SpvOpUndef: case SpvOpUndef:
@ -3322,10 +3322,9 @@ ast::StorageClass FunctionEmitter::GetStorageClassForPointerValue(uint32_t id) {
ast::type::Type* FunctionEmitter::RemapStorageClass(ast::type::Type* type, ast::type::Type* FunctionEmitter::RemapStorageClass(ast::type::Type* type,
uint32_t result_id) { uint32_t result_id) {
if (type->Is<ast::type::Pointer>()) { if (const auto* ast_ptr_type = type->As<ast::type::Pointer>()) {
// Remap an old-style storage buffer pointer to a new-style storage // Remap an old-style storage buffer pointer to a new-style storage
// buffer pointer. // buffer pointer.
const auto* ast_ptr_type = type->As<ast::type::Pointer>();
const auto sc = GetStorageClassForPointerValue(result_id); const auto sc = GetStorageClassForPointerValue(result_id);
if (ast_ptr_type->storage_class() != sc) { if (ast_ptr_type->storage_class() != sc) {
return parser_impl_.get_module().create<ast::type::Pointer>( return parser_impl_.get_module().create<ast::type::Pointer>(

View File

@ -1349,8 +1349,7 @@ ast::Expression* ParserImpl::MakeNullValue(ast::type::Type* type) {
return create<ast::ScalarConstructorExpression>( return create<ast::ScalarConstructorExpression>(
create<ast::FloatLiteral>(type, 0.0f)); create<ast::FloatLiteral>(type, 0.0f));
} }
if (type->Is<ast::type::Vector>()) { if (const auto* vec_ty = type->As<ast::type::Vector>()) {
const auto* vec_ty = type->As<ast::type::Vector>();
ast::ExpressionList ast_components; ast::ExpressionList ast_components;
for (size_t i = 0; i < vec_ty->size(); ++i) { for (size_t i = 0; i < vec_ty->size(); ++i) {
ast_components.emplace_back(MakeNullValue(vec_ty->type())); ast_components.emplace_back(MakeNullValue(vec_ty->type()));
@ -1358,8 +1357,7 @@ ast::Expression* ParserImpl::MakeNullValue(ast::type::Type* type) {
return create<ast::TypeConstructorExpression>(type, return create<ast::TypeConstructorExpression>(type,
std::move(ast_components)); std::move(ast_components));
} }
if (type->Is<ast::type::Matrix>()) { if (const auto* mat_ty = type->As<ast::type::Matrix>()) {
const auto* mat_ty = type->As<ast::type::Matrix>();
// Matrix components are columns // Matrix components are columns
auto* column_ty = auto* column_ty =
ast_module_.create<ast::type::Vector>(mat_ty->type(), mat_ty->rows()); ast_module_.create<ast::type::Vector>(mat_ty->type(), mat_ty->rows());
@ -1370,8 +1368,7 @@ ast::Expression* ParserImpl::MakeNullValue(ast::type::Type* type) {
return create<ast::TypeConstructorExpression>(type, return create<ast::TypeConstructorExpression>(type,
std::move(ast_components)); std::move(ast_components));
} }
if (type->Is<ast::type::Array>()) { if (auto* arr_ty = type->As<ast::type::Array>()) {
auto* arr_ty = type->As<ast::type::Array>();
ast::ExpressionList ast_components; ast::ExpressionList ast_components;
for (size_t i = 0; i < arr_ty->size(); ++i) { for (size_t i = 0; i < arr_ty->size(); ++i) {
ast_components.emplace_back(MakeNullValue(arr_ty->type())); ast_components.emplace_back(MakeNullValue(arr_ty->type()));
@ -1379,8 +1376,7 @@ ast::Expression* ParserImpl::MakeNullValue(ast::type::Type* type) {
return create<ast::TypeConstructorExpression>(original_type, return create<ast::TypeConstructorExpression>(original_type,
std::move(ast_components)); std::move(ast_components));
} }
if (type->Is<ast::type::Struct>()) { if (auto* struct_ty = type->As<ast::type::Struct>()) {
auto* struct_ty = type->As<ast::type::Struct>();
ast::ExpressionList ast_components; ast::ExpressionList ast_components;
for (auto* member : struct_ty->impl()->members()) { for (auto* member : struct_ty->impl()->members()) {
ast_components.emplace_back(MakeNullValue(member->type())); ast_components.emplace_back(MakeNullValue(member->type()));

View File

@ -2941,8 +2941,8 @@ std::vector<T*> ParserImpl::take_decorations(ast::DecorationList& in) {
std::vector<T*> out; std::vector<T*> out;
out.reserve(in.size()); out.reserve(in.size());
for (auto* deco : in) { for (auto* deco : in) {
if (deco->Is<T>()) { if (auto* t = deco->As<T>()) {
out.emplace_back(deco->As<T>()); out.emplace_back(t);
} else { } else {
remaining.emplace_back(deco); remaining.emplace_back(deco);
} }

View File

@ -37,7 +37,7 @@ TEST_F(ParserImplTest, DepthTextureType_2d) {
EXPECT_FALSE(t.errored); EXPECT_FALSE(t.errored);
ASSERT_NE(t.value, nullptr); ASSERT_NE(t.value, nullptr);
ASSERT_TRUE(t->Is<ast::type::Texture>()); ASSERT_TRUE(t->Is<ast::type::Texture>());
ASSERT_TRUE(t->As<ast::type::Texture>()->Is<ast::type::DepthTexture>()); ASSERT_TRUE(t->Is<ast::type::DepthTexture>());
EXPECT_EQ(t->As<ast::type::Texture>()->dim(), EXPECT_EQ(t->As<ast::type::Texture>()->dim(),
ast::type::TextureDimension::k2d); ast::type::TextureDimension::k2d);
EXPECT_FALSE(p->has_error()); EXPECT_FALSE(p->has_error());
@ -50,7 +50,7 @@ TEST_F(ParserImplTest, DepthTextureType_2dArray) {
EXPECT_FALSE(t.errored); EXPECT_FALSE(t.errored);
ASSERT_NE(t.value, nullptr); ASSERT_NE(t.value, nullptr);
ASSERT_TRUE(t->Is<ast::type::Texture>()); ASSERT_TRUE(t->Is<ast::type::Texture>());
ASSERT_TRUE(t->As<ast::type::Texture>()->Is<ast::type::DepthTexture>()); ASSERT_TRUE(t->Is<ast::type::DepthTexture>());
EXPECT_EQ(t->As<ast::type::Texture>()->dim(), EXPECT_EQ(t->As<ast::type::Texture>()->dim(),
ast::type::TextureDimension::k2dArray); ast::type::TextureDimension::k2dArray);
EXPECT_FALSE(p->has_error()); EXPECT_FALSE(p->has_error());
@ -63,7 +63,7 @@ TEST_F(ParserImplTest, DepthTextureType_Cube) {
EXPECT_FALSE(t.errored); EXPECT_FALSE(t.errored);
ASSERT_NE(t.value, nullptr); ASSERT_NE(t.value, nullptr);
ASSERT_TRUE(t->Is<ast::type::Texture>()); ASSERT_TRUE(t->Is<ast::type::Texture>());
ASSERT_TRUE(t->As<ast::type::Texture>()->Is<ast::type::DepthTexture>()); ASSERT_TRUE(t->Is<ast::type::DepthTexture>());
EXPECT_EQ(t->As<ast::type::Texture>()->dim(), EXPECT_EQ(t->As<ast::type::Texture>()->dim(),
ast::type::TextureDimension::kCube); ast::type::TextureDimension::kCube);
EXPECT_FALSE(p->has_error()); EXPECT_FALSE(p->has_error());
@ -76,7 +76,7 @@ TEST_F(ParserImplTest, DepthTextureType_CubeArray) {
EXPECT_FALSE(t.errored); EXPECT_FALSE(t.errored);
ASSERT_NE(t.value, nullptr); ASSERT_NE(t.value, nullptr);
ASSERT_TRUE(t->Is<ast::type::Texture>()); ASSERT_TRUE(t->Is<ast::type::Texture>());
ASSERT_TRUE(t->As<ast::type::Texture>()->Is<ast::type::DepthTexture>()); ASSERT_TRUE(t->Is<ast::type::DepthTexture>());
EXPECT_EQ(t->As<ast::type::Texture>()->dim(), EXPECT_EQ(t->As<ast::type::Texture>()->dim(),
ast::type::TextureDimension::kCubeArray); ast::type::TextureDimension::kCubeArray);
EXPECT_FALSE(p->has_error()); EXPECT_FALSE(p->has_error());

View File

@ -213,20 +213,20 @@ bool BoundArrayAccessorsTransform::ProcessAccessExpression(
// Scalar constructor we can re-write the value to be within bounds. // Scalar constructor we can re-write the value to be within bounds.
if (auto* c = expr->idx_expr()->As<ast::ScalarConstructorExpression>()) { if (auto* c = expr->idx_expr()->As<ast::ScalarConstructorExpression>()) {
auto* lit = c->literal(); auto* lit = c->literal();
if (lit->Is<ast::SintLiteral>()) { if (auto* sint = lit->As<ast::SintLiteral>()) {
int32_t val = lit->As<ast::SintLiteral>()->value(); int32_t val = sint->value();
if (val < 0) { if (val < 0) {
val = 0; val = 0;
} else if (val >= int32_t(size)) { } else if (val >= int32_t(size)) {
val = int32_t(size) - 1; val = int32_t(size) - 1;
} }
lit->As<ast::SintLiteral>()->set_value(val); sint->set_value(val);
} else if (lit->Is<ast::UintLiteral>()) { } else if (auto* uint = lit->As<ast::UintLiteral>()) {
uint32_t val = lit->As<ast::UintLiteral>()->value(); uint32_t val = uint->value();
if (val >= size - 1) { if (val >= size - 1) {
val = size - 1; val = size - 1;
} }
lit->As<ast::UintLiteral>()->set_value(val); uint->set_value(val);
} else { } else {
error_ = "unknown scalar constructor type for accessor"; error_ = "unknown scalar constructor type for accessor";
return false; return false;

View File

@ -128,11 +128,11 @@ void VertexPullingTransform::FindOrInsertVertexIndexIfUsed() {
} }
for (auto* d : v->As<ast::DecoratedVariable>()->decorations()) { for (auto* d : v->As<ast::DecoratedVariable>()->decorations()) {
if (d->Is<ast::BuiltinDecoration>() && if (auto* builtin = d->As<ast::BuiltinDecoration>()) {
d->As<ast::BuiltinDecoration>()->value() == if (builtin->value() == ast::Builtin::kVertexIdx) {
ast::Builtin::kVertexIdx) { vertex_index_name_ = v->name();
vertex_index_name_ = v->name(); return;
return; }
} }
} }
} }
@ -172,11 +172,11 @@ void VertexPullingTransform::FindOrInsertInstanceIndexIfUsed() {
} }
for (auto* d : v->As<ast::DecoratedVariable>()->decorations()) { for (auto* d : v->As<ast::DecoratedVariable>()->decorations()) {
if (d->Is<ast::BuiltinDecoration>() && if (auto* builtin = d->As<ast::BuiltinDecoration>()) {
d->As<ast::BuiltinDecoration>()->value() == if (builtin->value() == ast::Builtin::kInstanceIdx) {
ast::Builtin::kInstanceIdx) { instance_index_name_ = v->name();
instance_index_name_ = v->name(); return;
return; }
} }
} }
} }

View File

@ -88,15 +88,13 @@ void TypeDeterminer::set_referenced_from_function_if_needed(
bool TypeDeterminer::Determine() { bool TypeDeterminer::Determine() {
for (auto& iter : mod_->types()) { for (auto& iter : mod_->types()) {
auto& type = iter.second; auto& type = iter.second;
if (!type->Is<ast::type::Texture>() || if (auto* storage = type->As<ast::type::StorageTexture>()) {
!type->Is<ast::type::StorageTexture>()) { if (!DetermineStorageTextureSubtype(storage)) {
continue; set_error(Source{},
} "unable to determine storage texture subtype for: " +
if (!DetermineStorageTextureSubtype( type->type_name());
type->As<ast::type::StorageTexture>())) { return false;
set_error(Source{}, "unable to determine storage texture subtype for: " + }
type->type_name());
return false;
} }
} }
@ -180,11 +178,12 @@ bool TypeDeterminer::DetermineStatements(const ast::BlockStatement* stmts) {
} }
bool TypeDeterminer::DetermineVariableStorageClass(ast::Statement* stmt) { bool TypeDeterminer::DetermineVariableStorageClass(ast::Statement* stmt) {
if (!stmt->Is<ast::VariableDeclStatement>()) { auto* var_decl = stmt->As<ast::VariableDeclStatement>();
if (var_decl == nullptr) {
return true; return true;
} }
auto* var = stmt->As<ast::VariableDeclStatement>()->variable(); auto* var = var_decl->variable();
// Nothing to do for const // Nothing to do for const
if (var->is_const()) { if (var->is_const()) {
return true; return true;
@ -330,13 +329,12 @@ bool TypeDeterminer::DetermineArrayAccessor(
auto* res = expr->array()->result_type(); auto* res = expr->array()->result_type();
auto* parent_type = res->UnwrapAll(); auto* parent_type = res->UnwrapAll();
ast::type::Type* ret = nullptr; ast::type::Type* ret = nullptr;
if (parent_type->Is<ast::type::Array>()) { if (auto* arr = parent_type->As<ast::type::Array>()) {
ret = parent_type->As<ast::type::Array>()->type(); ret = arr->type();
} else if (parent_type->Is<ast::type::Vector>()) { } else if (auto* vec = parent_type->As<ast::type::Vector>()) {
ret = parent_type->As<ast::type::Vector>()->type(); ret = vec->type();
} else if (parent_type->Is<ast::type::Matrix>()) { } else if (auto* mat = parent_type->As<ast::type::Matrix>()) {
auto* m = parent_type->As<ast::type::Matrix>(); ret = mod_->create<ast::type::Vector>(mat->type(), mat->rows());
ret = mod_->create<ast::type::Vector>(m->type(), m->rows());
} else { } else {
set_error(expr->source(), "invalid parent type (" + set_error(expr->source(), "invalid parent type (" +
parent_type->type_name() + parent_type->type_name() +
@ -345,15 +343,15 @@ bool TypeDeterminer::DetermineArrayAccessor(
} }
// If we're extracting from a pointer, we return a pointer. // If we're extracting from a pointer, we return a pointer.
if (res->Is<ast::type::Pointer>()) { if (auto* ptr = res->As<ast::type::Pointer>()) {
ret = mod_->create<ast::type::Pointer>( ret = mod_->create<ast::type::Pointer>(ret, ptr->storage_class());
ret, res->As<ast::type::Pointer>()->storage_class()); } else if (auto* arr = parent_type->As<ast::type::Array>()) {
} else if (parent_type->Is<ast::type::Array>() && if (!arr->type()->is_scalar()) {
!parent_type->As<ast::type::Array>()->type()->is_scalar()) { // If we extract a non-scalar from an array then we also get a pointer. We
// If we extract a non-scalar from an array then we also get a pointer. We // will generate a Function storage class variable to store this
// will generate a Function storage class variable to store this // into.
// into. ret = mod_->create<ast::type::Pointer>(ret, ast::StorageClass::kFunction);
ret = mod_->create<ast::type::Pointer>(ret, ast::StorageClass::kFunction); }
} }
expr->set_result_type(ret); expr->set_result_type(ret);
@ -532,9 +530,9 @@ bool TypeDeterminer::DetermineIntrinsic(ast::IdentifierExpression* ident,
auto* bool_type = mod_->create<ast::type::Bool>(); auto* bool_type = mod_->create<ast::type::Bool>();
auto* param_type = expr->params()[0]->result_type()->UnwrapPtrIfNeeded(); auto* param_type = expr->params()[0]->result_type()->UnwrapPtrIfNeeded();
if (param_type->Is<ast::type::Vector>()) { if (auto* vec = param_type->As<ast::type::Vector>()) {
expr->func()->set_result_type(mod_->create<ast::type::Vector>( expr->func()->set_result_type(
bool_type, param_type->As<ast::type::Vector>()->size())); mod_->create<ast::type::Vector>(bool_type, vec->size()));
} else { } else {
expr->func()->set_result_type(bool_type); expr->func()->set_result_type(bool_type);
} }
@ -662,20 +660,13 @@ bool TypeDeterminer::DetermineIntrinsic(ast::IdentifierExpression* ident,
return true; return true;
} }
if (!texture->Is<ast::type::StorageTexture>() &&
!(texture->Is<ast::type::SampledTexture>() ||
texture->Is<ast::type::MultisampledTexture>())) {
set_error(expr->source(), "invalid texture for " + ident->name());
return false;
}
ast::type::Type* type = nullptr; ast::type::Type* type = nullptr;
if (texture->Is<ast::type::StorageTexture>()) { if (auto* storage = texture->As<ast::type::StorageTexture>()) {
type = texture->As<ast::type::StorageTexture>()->type(); type = storage->type();
} else if (texture->Is<ast::type::SampledTexture>()) { } else if (auto* sampled = texture->As<ast::type::SampledTexture>()) {
type = texture->As<ast::type::SampledTexture>()->type(); type = sampled->type();
} else if (texture->Is<ast::type::MultisampledTexture>()) { } else if (auto* msampled = texture->As<ast::type::MultisampledTexture>()) {
type = texture->As<ast::type::MultisampledTexture>()->type(); type = msampled->type();
} else { } else {
set_error(expr->source(), "unknown texture type for texture sampling"); set_error(expr->source(), "unknown texture type for texture sampling");
return false; return false;
@ -1030,8 +1021,8 @@ bool TypeDeterminer::DetermineMemberAccessor(
auto* data_type = res->UnwrapPtrIfNeeded()->UnwrapIfNeeded(); auto* data_type = res->UnwrapPtrIfNeeded()->UnwrapIfNeeded();
ast::type::Type* ret = nullptr; ast::type::Type* ret = nullptr;
if (data_type->Is<ast::type::Struct>()) { if (auto* ty = data_type->As<ast::type::Struct>()) {
auto* strct = data_type->As<ast::type::Struct>()->impl(); auto* strct = ty->impl();
auto name = expr->member()->name(); auto name = expr->member()->name();
for (auto* member : strct->members()) { for (auto* member : strct->members()) {
@ -1047,21 +1038,17 @@ bool TypeDeterminer::DetermineMemberAccessor(
} }
// If we're extracting from a pointer, we return a pointer. // If we're extracting from a pointer, we return a pointer.
if (res->Is<ast::type::Pointer>()) { if (auto* ptr = res->As<ast::type::Pointer>()) {
ret = mod_->create<ast::type::Pointer>( ret = mod_->create<ast::type::Pointer>(ret, ptr->storage_class());
ret, res->As<ast::type::Pointer>()->storage_class());
} }
} else if (data_type->Is<ast::type::Vector>()) { } else if (auto* vec = data_type->As<ast::type::Vector>()) {
auto* vec = data_type->As<ast::type::Vector>();
auto size = expr->member()->name().size(); auto size = expr->member()->name().size();
if (size == 1) { if (size == 1) {
// A single element swizzle is just the type of the vector. // A single element swizzle is just the type of the vector.
ret = vec->type(); ret = vec->type();
// If we're extracting from a pointer, we return a pointer. // If we're extracting from a pointer, we return a pointer.
if (res->Is<ast::type::Pointer>()) { if (auto* ptr = res->As<ast::type::Pointer>()) {
ret = mod_->create<ast::type::Pointer>( ret = mod_->create<ast::type::Pointer>(ret, ptr->storage_class());
ret, res->As<ast::type::Pointer>()->storage_class());
} }
} else { } else {
// The vector will have a number of components equal to the length of the // The vector will have a number of components equal to the length of the
@ -1100,9 +1087,9 @@ bool TypeDeterminer::DetermineBinary(ast::BinaryExpression* expr) {
expr->IsLessThanEqual() || expr->IsGreaterThanEqual()) { expr->IsLessThanEqual() || expr->IsGreaterThanEqual()) {
auto* bool_type = mod_->create<ast::type::Bool>(); auto* bool_type = mod_->create<ast::type::Bool>();
auto* param_type = expr->lhs()->result_type()->UnwrapPtrIfNeeded(); auto* param_type = expr->lhs()->result_type()->UnwrapPtrIfNeeded();
if (param_type->Is<ast::type::Vector>()) { if (auto* vec = param_type->As<ast::type::Vector>()) {
expr->set_result_type(mod_->create<ast::type::Vector>( expr->set_result_type(
bool_type, param_type->As<ast::type::Vector>()->size())); mod_->create<ast::type::Vector>(bool_type, vec->size()));
} else { } else {
expr->set_result_type(bool_type); expr->set_result_type(bool_type);
} }
@ -1114,36 +1101,31 @@ bool TypeDeterminer::DetermineBinary(ast::BinaryExpression* expr) {
// Note, the ordering here matters. The later checks depend on the prior // Note, the ordering here matters. The later checks depend on the prior
// checks having been done. // checks having been done.
if (lhs_type->Is<ast::type::Matrix>() && auto* lhs_mat = lhs_type->As<ast::type::Matrix>();
rhs_type->Is<ast::type::Matrix>()) { auto* rhs_mat = rhs_type->As<ast::type::Matrix>();
auto* lhs_vec = lhs_type->As<ast::type::Vector>();
auto* rhs_vec = rhs_type->As<ast::type::Vector>();
if (lhs_mat && rhs_mat) {
expr->set_result_type(mod_->create<ast::type::Matrix>( expr->set_result_type(mod_->create<ast::type::Matrix>(
lhs_type->As<ast::type::Matrix>()->type(), lhs_mat->type(), lhs_mat->rows(), rhs_mat->columns()));
lhs_type->As<ast::type::Matrix>()->rows(), } else if (lhs_mat && rhs_vec) {
rhs_type->As<ast::type::Matrix>()->columns()));
} else if (lhs_type->Is<ast::type::Matrix>() &&
rhs_type->Is<ast::type::Vector>()) {
auto* mat = lhs_type->As<ast::type::Matrix>();
expr->set_result_type( expr->set_result_type(
mod_->create<ast::type::Vector>(mat->type(), mat->rows())); mod_->create<ast::type::Vector>(lhs_mat->type(), lhs_mat->rows()));
} else if (lhs_type->Is<ast::type::Vector>() && } else if (lhs_vec && rhs_mat) {
rhs_type->Is<ast::type::Matrix>()) {
auto* mat = rhs_type->As<ast::type::Matrix>();
expr->set_result_type( expr->set_result_type(
mod_->create<ast::type::Vector>(mat->type(), mat->columns())); mod_->create<ast::type::Vector>(rhs_mat->type(), rhs_mat->columns()));
} else if (lhs_type->Is<ast::type::Matrix>()) { } else if (lhs_mat) {
// matrix * scalar // matrix * scalar
expr->set_result_type(lhs_type); expr->set_result_type(lhs_type);
} else if (rhs_type->Is<ast::type::Matrix>()) { } else if (rhs_mat) {
// scalar * matrix // scalar * matrix
expr->set_result_type(rhs_type); expr->set_result_type(rhs_type);
} else if (lhs_type->Is<ast::type::Vector>() && } else if (lhs_vec && rhs_vec) {
rhs_type->Is<ast::type::Vector>()) {
expr->set_result_type(lhs_type); expr->set_result_type(lhs_type);
} else if (lhs_type->Is<ast::type::Vector>()) { } else if (lhs_vec) {
// Vector * scalar // Vector * scalar
expr->set_result_type(lhs_type); expr->set_result_type(lhs_type);
} else if (rhs_type->Is<ast::type::Vector>()) { } else if (rhs_vec) {
// Scalar * vector // Scalar * vector
expr->set_result_type(rhs_type); expr->set_result_type(rhs_type);
} else { } else {

View File

@ -85,11 +85,9 @@ bool ValidatorImpl::Validate(const ast::Module* module) {
bool ValidatorImpl::ValidateConstructedTypes( bool ValidatorImpl::ValidateConstructedTypes(
const std::vector<ast::type::Type*>& constructed_types) { const std::vector<ast::type::Type*>& constructed_types) {
for (auto* const ct : constructed_types) { for (auto* const ct : constructed_types) {
if (ct->Is<ast::type::Struct>()) { if (auto* st = ct->As<ast::type::Struct>()) {
auto* st = ct->As<ast::type::Struct>();
for (auto* member : st->impl()->members()) { for (auto* member : st->impl()->members()) {
if (member->type()->UnwrapAll()->Is<ast::type::Array>()) { if (auto* r = member->type()->UnwrapAll()->As<ast::type::Array>()) {
auto* r = member->type()->UnwrapAll()->As<ast::type::Array>();
if (r->IsRuntimeArray()) { if (r->IsRuntimeArray()) {
if (member != st->impl()->members().back()) { if (member != st->impl()->members().back()) {
add_error(member->source(), "v-0015", add_error(member->source(), "v-0015",
@ -265,12 +263,9 @@ bool ValidatorImpl::ValidateDeclStatement(
return false; return false;
} }
variable_stack_.set(name, decl->variable()); variable_stack_.set(name, decl->variable());
if (decl->variable()->type()->UnwrapAll()->Is<ast::type::Array>()) { if (auto* arr =
if (decl->variable() decl->variable()->type()->UnwrapAll()->As<ast::type::Array>()) {
->type() if (arr->IsRuntimeArray()) {
->UnwrapAll()
->As<ast::type::Array>()
->IsRuntimeArray()) {
add_error(decl->source(), "v-0015", add_error(decl->source(), "v-0015",
"runtime arrays may only appear as the last " "runtime arrays may only appear as the last "
"member of a struct: '" + "member of a struct: '" +
@ -317,7 +312,7 @@ bool ValidatorImpl::ValidateSwitch(const ast::SwitchStatement* s) {
} }
auto* cond_type = s->condition()->result_type()->UnwrapAll(); auto* cond_type = s->condition()->result_type()->UnwrapAll();
if (!(cond_type->Is<ast::type::I32>() || cond_type->Is<ast::type::U32>())) { if (!cond_type->is_integer_scalar()) {
add_error(s->condition()->source(), "v-0025", add_error(s->condition()->source(), "v-0025",
"switch statement selector expression must be of a " "switch statement selector expression must be of a "
"scalar integer type"); "scalar integer type");

View File

@ -219,13 +219,11 @@ bool GeneratorImpl::EmitConstructedType(std::ostream& out,
const ast::type::Type* ty) { const ast::type::Type* ty) {
make_indent(out); make_indent(out);
if (ty->Is<ast::type::Alias>()) { if (auto* alias = ty->As<ast::type::Alias>()) {
auto* alias = ty->As<ast::type::Alias>();
// HLSL typedef is for intrinsic types only. For an alias'd struct, // HLSL typedef is for intrinsic types only. For an alias'd struct,
// generate a secondary struct with the new name. // generate a secondary struct with the new name.
if (alias->type()->Is<ast::type::Struct>()) { if (auto* str = alias->type()->As<ast::type::Struct>()) {
if (!EmitStructType(out, alias->type()->As<ast::type::Struct>(), if (!EmitStructType(out, str, alias->name())) {
alias->name())) {
return false; return false;
} }
return true; return true;
@ -235,8 +233,7 @@ bool GeneratorImpl::EmitConstructedType(std::ostream& out,
return false; return false;
} }
out << " " << namer_.NameFor(alias->name()) << ";" << std::endl; out << " " << namer_.NameFor(alias->name()) << ";" << std::endl;
} else if (ty->Is<ast::type::Struct>()) { } else if (auto* str = ty->As<ast::type::Struct>()) {
auto* str = ty->As<ast::type::Struct>();
if (!EmitStructType(out, str, str->name())) { if (!EmitStructType(out, str, str->name())) {
return false; return false;
} }
@ -272,9 +269,7 @@ bool GeneratorImpl::EmitArrayAccessor(std::ostream& pre,
bool GeneratorImpl::EmitBitcast(std::ostream& pre, bool GeneratorImpl::EmitBitcast(std::ostream& pre,
std::ostream& out, std::ostream& out,
ast::BitcastExpression* expr) { ast::BitcastExpression* expr) {
if (!expr->type()->Is<ast::type::F32>() && if (!expr->type()->is_integer_scalar() && !expr->type()->is_float_scalar()) {
!expr->type()->Is<ast::type::I32>() &&
!expr->type()->Is<ast::type::U32>()) {
error_ = "Unable to do bitcast to type " + expr->type()->type_name(); error_ = "Unable to do bitcast to type " + expr->type()->type_name();
return false; return false;
} }
@ -1005,11 +1000,14 @@ bool GeneratorImpl::EmitExpression(std::ostream& pre,
} }
bool GeneratorImpl::global_is_in_struct(ast::Variable* var) const { bool GeneratorImpl::global_is_in_struct(ast::Variable* var) const {
return var->Is<ast::DecoratedVariable>() && if (auto* decorated = var->As<ast::DecoratedVariable>()) {
(var->As<ast::DecoratedVariable>()->HasLocationDecoration() || if (decorated->HasLocationDecoration() ||
var->As<ast::DecoratedVariable>()->HasBuiltinDecoration()) && decorated->HasBuiltinDecoration()) {
(var->storage_class() == ast::StorageClass::kInput || return var->storage_class() == ast::StorageClass::kInput ||
var->storage_class() == ast::StorageClass::kOutput); var->storage_class() == ast::StorageClass::kOutput;
}
}
return false;
} }
bool GeneratorImpl::EmitIdentifier(std::ostream&, bool GeneratorImpl::EmitIdentifier(std::ostream&,
@ -1298,9 +1296,7 @@ bool GeneratorImpl::EmitEntryPointData(
emitted_globals.insert(var->name()); emitted_globals.insert(var->name());
auto* type = var->type()->UnwrapIfNeeded(); auto* type = var->type()->UnwrapIfNeeded();
if (type->Is<ast::type::Struct>()) { if (auto* strct = type->As<ast::type::Struct>()) {
auto* strct = type->As<ast::type::Struct>();
out << "ConstantBuffer<" << strct->name() << "> " << var->name() out << "ConstantBuffer<" << strct->name() << "> " << var->name()
<< " : register(b" << binding->value() << ");" << std::endl; << " : register(b" << binding->value() << ");" << std::endl;
} else { } else {
@ -1340,11 +1336,11 @@ bool GeneratorImpl::EmitEntryPointData(
} }
emitted_globals.insert(var->name()); emitted_globals.insert(var->name());
if (!var->type()->Is<ast::type::AccessControl>()) { auto* ac = var->type()->As<ast::type::AccessControl>();
if (ac == nullptr) {
error_ = "access control type required for storage buffer"; error_ = "access control type required for storage buffer";
return false; return false;
} }
auto* ac = var->type()->As<ast::type::AccessControl>();
if (ac->IsReadWrite()) { if (ac->IsReadWrite()) {
out << "RW"; out << "RW";
@ -1538,14 +1534,14 @@ bool GeneratorImpl::EmitEntryPointFunction(std::ostream& out,
} }
bool GeneratorImpl::EmitLiteral(std::ostream& out, ast::Literal* lit) { bool GeneratorImpl::EmitLiteral(std::ostream& out, ast::Literal* lit) {
if (lit->Is<ast::BoolLiteral>()) { if (auto* l = lit->As<ast::BoolLiteral>()) {
out << (lit->As<ast::BoolLiteral>()->IsTrue() ? "true" : "false"); out << (l->IsTrue() ? "true" : "false");
} else if (lit->Is<ast::FloatLiteral>()) { } else if (auto* fl = lit->As<ast::FloatLiteral>()) {
out << FloatToString(lit->As<ast::FloatLiteral>()->value()) << "f"; out << FloatToString(fl->value()) << "f";
} else if (lit->Is<ast::SintLiteral>()) { } else if (auto* sl = lit->As<ast::SintLiteral>()) {
out << lit->As<ast::SintLiteral>()->value(); out << sl->value();
} else if (lit->Is<ast::UintLiteral>()) { } else if (auto* ul = lit->As<ast::UintLiteral>()) {
out << lit->As<ast::UintLiteral>()->value() << "u"; out << ul->value() << "u";
} else { } else {
error_ = "unknown literal type"; error_ = "unknown literal type";
return false; return false;
@ -1562,10 +1558,9 @@ bool GeneratorImpl::EmitZeroValue(std::ostream& out, ast::type::Type* type) {
out << "0"; out << "0";
} else if (type->Is<ast::type::U32>()) { } else if (type->Is<ast::type::U32>()) {
out << "0u"; out << "0u";
} else if (type->Is<ast::type::Vector>()) { } else if (auto* vec = type->As<ast::type::Vector>()) {
return EmitZeroValue(out, type->As<ast::type::Vector>()->type()); return EmitZeroValue(out, vec->type());
} else if (type->Is<ast::type::Matrix>()) { } else if (auto* mat = type->As<ast::type::Matrix>()) {
auto* mat = type->As<ast::type::Matrix>();
for (uint32_t i = 0; i < (mat->rows() * mat->columns()); i++) { for (uint32_t i = 0; i < (mat->rows() * mat->columns()); i++) {
if (i != 0) { if (i != 0) {
out << ", "; out << ", ";
@ -1630,31 +1625,32 @@ bool GeneratorImpl::EmitLoop(std::ostream& out, ast::LoopStatement* stmt) {
for (auto* s : *(stmt->body())) { for (auto* s : *(stmt->body())) {
// If we have a continuing block we've already emitted the variable // If we have a continuing block we've already emitted the variable
// declaration before the loop, so treat it as an assignment. // declaration before the loop, so treat it as an assignment.
auto* decl = s->As<ast::VariableDeclStatement>(); if (auto* decl = s->As<ast::VariableDeclStatement>()) {
if (decl != nullptr && stmt->has_continuing()) { if (stmt->has_continuing()) {
make_indent(out); make_indent(out);
auto* var = decl->variable(); auto* var = decl->variable();
std::ostringstream pre; std::ostringstream pre;
std::ostringstream constructor_out; std::ostringstream constructor_out;
if (var->constructor() != nullptr) { if (var->constructor() != nullptr) {
if (!EmitExpression(pre, constructor_out, var->constructor())) { if (!EmitExpression(pre, constructor_out, var->constructor())) {
return false; return false;
}
} }
} out << pre.str();
out << pre.str();
out << var->name() << " = "; out << var->name() << " = ";
if (var->constructor() != nullptr) { if (var->constructor() != nullptr) {
out << constructor_out.str(); out << constructor_out.str();
} else { } else {
if (!EmitZeroValue(out, var->type())) { if (!EmitZeroValue(out, var->type())) {
return false; return false;
}
} }
out << ";" << std::endl;
continue;
} }
out << ";" << std::endl;
continue;
} }
if (!EmitStatement(out, s)) { if (!EmitStatement(out, s)) {
@ -1692,8 +1688,8 @@ std::string GeneratorImpl::generate_storage_buffer_index_expression(
first = false; first = false;
if (auto* mem = expr->As<ast::MemberAccessorExpression>()) { if (auto* mem = expr->As<ast::MemberAccessorExpression>()) {
auto* res_type = mem->structure()->result_type()->UnwrapAll(); auto* res_type = mem->structure()->result_type()->UnwrapAll();
if (res_type->Is<ast::type::Struct>()) { if (auto* str = res_type->As<ast::type::Struct>()) {
auto* str_type = res_type->As<ast::type::Struct>()->impl(); auto* str_type = str->impl();
auto* str_member = str_type->get_member(mem->member()->name()); auto* str_member = str_type->get_member(mem->member()->name());
if (!str_member->has_offset_decoration()) { if (!str_member->has_offset_decoration()) {
@ -1728,15 +1724,14 @@ std::string GeneratorImpl::generate_storage_buffer_index_expression(
auto* ary_type = ary->array()->result_type()->UnwrapAll(); auto* ary_type = ary->array()->result_type()->UnwrapAll();
out << "("; out << "(";
if (ary_type->Is<ast::type::Array>()) { if (auto* arr = ary_type->As<ast::type::Array>()) {
out << ary_type->As<ast::type::Array>()->array_stride(); out << arr->array_stride();
} else if (ary_type->Is<ast::type::Vector>()) { } else if (ary_type->Is<ast::type::Vector>()) {
// TODO(dsinclair): This is a hack. Our vectors can only be f32, i32 // TODO(dsinclair): This is a hack. Our vectors can only be f32, i32
// or u32 which are all 4 bytes. When we get f16 or other types we'll // or u32 which are all 4 bytes. When we get f16 or other types we'll
// have to ask the type for the byte size. // have to ask the type for the byte size.
out << "4"; out << "4";
} else if (ary_type->Is<ast::type::Matrix>()) { } else if (auto* mat = ary_type->As<ast::type::Matrix>()) {
auto* mat = ary_type->As<ast::type::Matrix>();
if (mat->columns() == 2) { if (mat->columns() == 2) {
out << "8"; out << "8";
} else { } else {
@ -1777,12 +1772,10 @@ bool GeneratorImpl::EmitStorageBufferAccessor(std::ostream& pre,
bool is_store = rhs != nullptr; bool is_store = rhs != nullptr;
std::string access_method = is_store ? "Store" : "Load"; std::string access_method = is_store ? "Store" : "Load";
if (result_type->Is<ast::type::Vector>()) { if (auto* vec = result_type->As<ast::type::Vector>()) {
access_method += access_method += std::to_string(vec->size());
std::to_string(result_type->As<ast::type::Vector>()->size()); } else if (auto* mat = result_type->As<ast::type::Matrix>()) {
} else if (result_type->Is<ast::type::Matrix>()) { access_method += std::to_string(mat->rows());
access_method +=
std::to_string(result_type->As<ast::type::Matrix>()->rows());
} }
// If we aren't storing then we need to put in the outer cast. // If we aren't storing then we need to put in the outer cast.
@ -1808,9 +1801,7 @@ bool GeneratorImpl::EmitStorageBufferAccessor(std::ostream& pre,
return false; return false;
} }
if (result_type->Is<ast::type::Matrix>()) { if (auto* mat = result_type->As<ast::type::Matrix>()) {
auto* mat = result_type->As<ast::type::Matrix>();
// TODO(dsinclair): This is assuming 4 byte elements. Will need to be fixed // TODO(dsinclair): This is assuming 4 byte elements. Will need to be fixed
// if we get matrixes of f16 or f64. // if we get matrixes of f16 or f64.
uint32_t stride = mat->rows() == 2 ? 8 : 16; uint32_t stride = mat->rows() == 2 ? 8 : 16;
@ -2041,24 +2032,21 @@ bool GeneratorImpl::EmitSwitch(std::ostream& out, ast::SwitchStatement* stmt) {
bool GeneratorImpl::EmitType(std::ostream& out, bool GeneratorImpl::EmitType(std::ostream& out,
ast::type::Type* type, ast::type::Type* type,
const std::string& name) { const std::string& name) {
if (type->Is<ast::type::Alias>()) { if (auto* alias = type->As<ast::type::Alias>()) {
auto* alias = type->As<ast::type::Alias>();
out << namer_.NameFor(alias->name()); out << namer_.NameFor(alias->name());
} else if (type->Is<ast::type::Array>()) { } else if (auto* ary = type->As<ast::type::Array>()) {
auto* ary = type->As<ast::type::Array>();
ast::type::Type* base_type = ary; ast::type::Type* base_type = ary;
std::vector<uint32_t> sizes; std::vector<uint32_t> sizes;
while (base_type->Is<ast::type::Array>()) { while (auto* arr = base_type->As<ast::type::Array>()) {
if (base_type->As<ast::type::Array>()->IsRuntimeArray()) { if (arr->IsRuntimeArray()) {
// TODO(dsinclair): Support runtime arrays // TODO(dsinclair): Support runtime arrays
// https://bugs.chromium.org/p/tint/issues/detail?id=185 // https://bugs.chromium.org/p/tint/issues/detail?id=185
error_ = "runtime array not supported yet."; error_ = "runtime array not supported yet.";
return false; return false;
} else { } else {
sizes.push_back(base_type->As<ast::type::Array>()->size()); sizes.push_back(arr->size());
} }
base_type = base_type->As<ast::type::Array>()->type(); base_type = arr->type();
} }
if (!EmitType(out, base_type, "")) { if (!EmitType(out, base_type, "")) {
return false; return false;
@ -2075,8 +2063,7 @@ bool GeneratorImpl::EmitType(std::ostream& out,
out << "float"; out << "float";
} else if (type->Is<ast::type::I32>()) { } else if (type->Is<ast::type::I32>()) {
out << "int"; out << "int";
} else if (type->Is<ast::type::Matrix>()) { } else if (auto* mat = type->As<ast::type::Matrix>()) {
auto* mat = type->As<ast::type::Matrix>();
if (!EmitType(out, mat->type(), "")) { if (!EmitType(out, mat->type(), "")) {
return false; return false;
} }
@ -2086,17 +2073,15 @@ bool GeneratorImpl::EmitType(std::ostream& out,
// https://bugs.chromium.org/p/tint/issues/detail?id=183 // https://bugs.chromium.org/p/tint/issues/detail?id=183
error_ = "pointers not supported in HLSL"; error_ = "pointers not supported in HLSL";
return false; return false;
} else if (type->Is<ast::type::Sampler>()) { } else if (auto* sampler = type->As<ast::type::Sampler>()) {
auto* sampler = type->As<ast::type::Sampler>();
out << "Sampler"; out << "Sampler";
if (sampler->IsComparison()) { if (sampler->IsComparison()) {
out << "Comparison"; out << "Comparison";
} }
out << "State"; out << "State";
} else if (type->Is<ast::type::Struct>()) { } else if (auto* str = type->As<ast::type::Struct>()) {
out << type->As<ast::type::Struct>()->name(); out << str->name();
} else if (type->Is<ast::type::Texture>()) { } else if (auto* tex = type->As<ast::type::Texture>()) {
auto* tex = type->As<ast::type::Texture>();
if (tex->Is<ast::type::StorageTexture>()) { if (tex->Is<ast::type::StorageTexture>()) {
out << "RW"; out << "RW";
} }
@ -2131,8 +2116,7 @@ bool GeneratorImpl::EmitType(std::ostream& out,
} else if (type->Is<ast::type::U32>()) { } else if (type->Is<ast::type::U32>()) {
out << "uint"; out << "uint";
} else if (type->Is<ast::type::Vector>()) { } else if (auto* vec = type->As<ast::type::Vector>()) {
auto* vec = type->As<ast::type::Vector>();
auto size = vec->size(); auto size = vec->size();
if (vec->type()->Is<ast::type::F32>() && size >= 1 && size <= 4) { if (vec->type()->Is<ast::type::F32>() && size >= 1 && size <= 4) {
out << "float" << size; out << "float" << size;
@ -2250,8 +2234,8 @@ bool GeneratorImpl::EmitProgramConstVariable(std::ostream& out,
const ast::Variable* var) { const ast::Variable* var) {
make_indent(out); make_indent(out);
if (var->Is<ast::DecoratedVariable>() && auto* decorated = var->As<ast::DecoratedVariable>();
!var->As<ast::DecoratedVariable>()->HasConstantIdDecoration()) { if (decorated != nullptr && !decorated->HasConstantIdDecoration()) {
error_ = "Decorated const values not valid"; error_ = "Decorated const values not valid";
return false; return false;
} }
@ -2269,9 +2253,8 @@ bool GeneratorImpl::EmitProgramConstVariable(std::ostream& out,
out << pre.str(); out << pre.str();
} }
if (var->Is<ast::DecoratedVariable>() && if (decorated != nullptr && decorated->HasConstantIdDecoration()) {
var->As<ast::DecoratedVariable>()->HasConstantIdDecoration()) { auto const_id = decorated->constant_id();
auto const_id = var->As<ast::DecoratedVariable>()->constant_id();
out << "#ifndef WGSL_SPEC_CONSTANT_" << const_id << std::endl; out << "#ifndef WGSL_SPEC_CONSTANT_" << const_id << std::endl;

View File

@ -186,11 +186,10 @@ uint32_t GeneratorImpl::calculate_largest_alignment(ast::type::Struct* type) {
} }
uint32_t GeneratorImpl::calculate_alignment_size(ast::type::Type* type) { uint32_t GeneratorImpl::calculate_alignment_size(ast::type::Type* type) {
if (type->Is<ast::type::Alias>()) { if (auto* alias = type->As<ast::type::Alias>()) {
return calculate_alignment_size(type->As<ast::type::Alias>()->type()); return calculate_alignment_size(alias->type());
} }
if (type->Is<ast::type::Array>()) { if (auto* ary = type->As<ast::type::Array>()) {
auto* ary = type->As<ast::type::Array>();
// TODO(dsinclair): Handle array stride and adjust for alignment. // TODO(dsinclair): Handle array stride and adjust for alignment.
uint32_t type_size = calculate_alignment_size(ary->type()); uint32_t type_size = calculate_alignment_size(ary->type());
return ary->size() * type_size; return ary->size() * type_size;
@ -205,15 +204,14 @@ uint32_t GeneratorImpl::calculate_alignment_size(ast::type::Type* type) {
type->Is<ast::type::U32>()) { type->Is<ast::type::U32>()) {
return 4; return 4;
} }
if (type->Is<ast::type::Matrix>()) { if (auto* mat = type->As<ast::type::Matrix>()) {
auto* mat = type->As<ast::type::Matrix>();
// TODO(dsinclair): Handle MatrixStride // TODO(dsinclair): Handle MatrixStride
// https://github.com/gpuweb/gpuweb/issues/773 // https://github.com/gpuweb/gpuweb/issues/773
uint32_t type_size = calculate_alignment_size(mat->type()); uint32_t type_size = calculate_alignment_size(mat->type());
return mat->rows() * mat->columns() * type_size; return mat->rows() * mat->columns() * type_size;
} }
if (type->Is<ast::type::Struct>()) { if (auto* stct_ty = type->As<ast::type::Struct>()) {
auto* stct = type->As<ast::type::Struct>()->impl(); auto* stct = stct_ty->impl();
uint32_t count = 0; uint32_t count = 0;
uint32_t largest_alignment = 0; uint32_t largest_alignment = 0;
// Offset decorations in WGSL must be in increasing order. // Offset decorations in WGSL must be in increasing order.
@ -227,12 +225,11 @@ uint32_t GeneratorImpl::calculate_alignment_size(ast::type::Type* type) {
if (align == 0) { if (align == 0) {
return 0; return 0;
} }
if (!mem->type()->Is<ast::type::Struct>()) { if (auto* str = mem->type()->As<ast::type::Struct>()) {
largest_alignment = std::max(largest_alignment, align); largest_alignment =
std::max(largest_alignment, calculate_largest_alignment(str));
} else { } else {
largest_alignment = std::max( largest_alignment = std::max(largest_alignment, align);
largest_alignment,
calculate_largest_alignment(mem->type()->As<ast::type::Struct>()));
} }
// Round up to the alignment size // Round up to the alignment size
@ -243,8 +240,7 @@ uint32_t GeneratorImpl::calculate_alignment_size(ast::type::Type* type) {
count = adjust_for_alignment(count, largest_alignment); count = adjust_for_alignment(count, largest_alignment);
return count; return count;
} }
if (type->Is<ast::type::Vector>()) { if (auto* vec = type->As<ast::type::Vector>()) {
auto* vec = type->As<ast::type::Vector>();
uint32_t type_size = calculate_alignment_size(vec->type()); uint32_t type_size = calculate_alignment_size(vec->type());
if (vec->size() == 2) { if (vec->size() == 2) {
return 2 * type_size; return 2 * type_size;
@ -257,16 +253,14 @@ uint32_t GeneratorImpl::calculate_alignment_size(ast::type::Type* type) {
bool GeneratorImpl::EmitConstructedType(const ast::type::Type* ty) { bool GeneratorImpl::EmitConstructedType(const ast::type::Type* ty) {
make_indent(); make_indent();
if (ty->Is<ast::type::Alias>()) { if (auto* alias = ty->As<ast::type::Alias>()) {
auto* alias = ty->As<ast::type::Alias>();
out_ << "typedef "; out_ << "typedef ";
if (!EmitType(alias->type(), "")) { if (!EmitType(alias->type(), "")) {
return false; return false;
} }
out_ << " " << namer_.NameFor(alias->name()) << ";" << std::endl; out_ << " " << namer_.NameFor(alias->name()) << ";" << std::endl;
} else if (ty->Is<ast::type::Struct>()) { } else if (auto* str = ty->As<ast::type::Struct>()) {
if (!EmitStructType(ty->As<ast::type::Struct>())) { if (!EmitStructType(str)) {
return false; return false;
} }
} else { } else {
@ -940,17 +934,17 @@ bool GeneratorImpl::EmitZeroValue(ast::type::Type* type) {
out_ << "0"; out_ << "0";
} else if (type->Is<ast::type::U32>()) { } else if (type->Is<ast::type::U32>()) {
out_ << "0u"; out_ << "0u";
} else if (type->Is<ast::type::Vector>()) { } else if (auto* vec = type->As<ast::type::Vector>()) {
return EmitZeroValue(type->As<ast::type::Vector>()->type()); return EmitZeroValue(vec->type());
} else if (type->Is<ast::type::Matrix>()) { } else if (auto* mat = type->As<ast::type::Matrix>()) {
return EmitZeroValue(type->As<ast::type::Matrix>()->type()); return EmitZeroValue(mat->type());
} else if (type->Is<ast::type::Array>()) { } else if (auto* arr = type->As<ast::type::Array>()) {
out_ << "{"; out_ << "{";
if (!EmitZeroValue(type->As<ast::type::Array>()->type())) { if (!EmitZeroValue(arr->type())) {
return false; return false;
} }
out_ << "}"; out_ << "}";
} else if (type->Is<ast::type::Struct>()) { } else if (type->As<ast::type::Struct>()) {
out_ << "{}"; out_ << "{}";
} else { } else {
error_ = "Invalid type for zero emission: " + type->type_name(); error_ = "Invalid type for zero emission: " + type->type_name();
@ -965,14 +959,14 @@ bool GeneratorImpl::EmitScalarConstructor(
} }
bool GeneratorImpl::EmitLiteral(ast::Literal* lit) { bool GeneratorImpl::EmitLiteral(ast::Literal* lit) {
if (lit->Is<ast::BoolLiteral>()) { if (auto* l = lit->As<ast::BoolLiteral>()) {
out_ << (lit->As<ast::BoolLiteral>()->IsTrue() ? "true" : "false"); out_ << (l->IsTrue() ? "true" : "false");
} else if (lit->Is<ast::FloatLiteral>()) { } else if (auto* fl = lit->As<ast::FloatLiteral>()) {
out_ << FloatToString(lit->As<ast::FloatLiteral>()->value()) << "f"; out_ << FloatToString(fl->value()) << "f";
} else if (lit->Is<ast::SintLiteral>()) { } else if (auto* sl = lit->As<ast::SintLiteral>()) {
out_ << lit->As<ast::SintLiteral>()->value(); out_ << sl->value();
} else if (lit->Is<ast::UintLiteral>()) { } else if (auto* ul = lit->As<ast::UintLiteral>()) {
out_ << lit->As<ast::UintLiteral>()->value() << "u"; out_ << ul->value() << "u";
} else { } else {
error_ = "unknown literal type"; error_ = "unknown literal type";
return false; return false;
@ -1286,11 +1280,11 @@ bool GeneratorImpl::EmitFunctionInternal(ast::Function* func,
} }
first = false; first = false;
if (!var->type()->Is<ast::type::AccessControl>()) { auto* ac = var->type()->As<ast::type::AccessControl>();
if (ac == nullptr) {
error_ = "invalid type for storage buffer, expected access control"; error_ = "invalid type for storage buffer, expected access control";
return false; return false;
} }
auto* ac = var->type()->As<ast::type::AccessControl>();
if (ac->IsReadOnly()) { if (ac->IsReadOnly()) {
out_ << "const "; out_ << "const ";
} }
@ -1447,11 +1441,11 @@ bool GeneratorImpl::EmitEntryPointFunction(ast::Function* func) {
auto* binding = data.second.binding; auto* binding = data.second.binding;
// auto* set = data.second.set; // auto* set = data.second.set;
if (!var->type()->Is<ast::type::AccessControl>()) { auto* ac = var->type()->As<ast::type::AccessControl>();
if (ac == nullptr) {
error_ = "invalid type for storage buffer, expected access control"; error_ = "invalid type for storage buffer, expected access control";
return false; return false;
} }
auto* ac = var->type()->As<ast::type::AccessControl>();
if (ac->IsReadOnly()) { if (ac->IsReadOnly()) {
out_ << "const "; out_ << "const ";
} }
@ -1490,14 +1484,13 @@ bool GeneratorImpl::EmitEntryPointFunction(ast::Function* func) {
} }
bool GeneratorImpl::global_is_in_struct(ast::Variable* var) const { bool GeneratorImpl::global_is_in_struct(ast::Variable* var) const {
auto* decorated = var->As<ast::DecoratedVariable>();
bool in_or_out_struct_has_location = bool in_or_out_struct_has_location =
var->Is<ast::DecoratedVariable>() && decorated != nullptr && decorated->HasLocationDecoration() &&
var->As<ast::DecoratedVariable>()->HasLocationDecoration() &&
(var->storage_class() == ast::StorageClass::kInput || (var->storage_class() == ast::StorageClass::kInput ||
var->storage_class() == ast::StorageClass::kOutput); var->storage_class() == ast::StorageClass::kOutput);
bool in_struct_has_builtin = bool in_struct_has_builtin =
var->Is<ast::DecoratedVariable>() && decorated != nullptr && decorated->HasBuiltinDecoration() &&
var->As<ast::DecoratedVariable>()->HasBuiltinDecoration() &&
var->storage_class() == ast::StorageClass::kOutput; var->storage_class() == ast::StorageClass::kOutput;
return in_or_out_struct_has_location || in_struct_has_builtin; return in_or_out_struct_has_location || in_struct_has_builtin;
} }
@ -1793,21 +1786,18 @@ bool GeneratorImpl::EmitSwitch(ast::SwitchStatement* stmt) {
} }
bool GeneratorImpl::EmitType(ast::type::Type* type, const std::string& name) { bool GeneratorImpl::EmitType(ast::type::Type* type, const std::string& name) {
if (type->Is<ast::type::Alias>()) { if (auto* alias = type->As<ast::type::Alias>()) {
auto* alias = type->As<ast::type::Alias>();
out_ << namer_.NameFor(alias->name()); out_ << namer_.NameFor(alias->name());
} else if (type->Is<ast::type::Array>()) { } else if (auto* ary = type->As<ast::type::Array>()) {
auto* ary = type->As<ast::type::Array>();
ast::type::Type* base_type = ary; ast::type::Type* base_type = ary;
std::vector<uint32_t> sizes; std::vector<uint32_t> sizes;
while (base_type->Is<ast::type::Array>()) { while (auto* arr = base_type->As<ast::type::Array>()) {
if (base_type->As<ast::type::Array>()->IsRuntimeArray()) { if (arr->IsRuntimeArray()) {
sizes.push_back(1); sizes.push_back(1);
} else { } else {
sizes.push_back(base_type->As<ast::type::Array>()->size()); sizes.push_back(arr->size());
} }
base_type = base_type->As<ast::type::Array>()->type(); base_type = arr->type();
} }
if (!EmitType(base_type, "")) { if (!EmitType(base_type, "")) {
return false; return false;
@ -1824,14 +1814,12 @@ bool GeneratorImpl::EmitType(ast::type::Type* type, const std::string& name) {
out_ << "float"; out_ << "float";
} else if (type->Is<ast::type::I32>()) { } else if (type->Is<ast::type::I32>()) {
out_ << "int"; out_ << "int";
} else if (type->Is<ast::type::Matrix>()) { } else if (auto* mat = type->As<ast::type::Matrix>()) {
auto* mat = type->As<ast::type::Matrix>();
if (!EmitType(mat->type(), "")) { if (!EmitType(mat->type(), "")) {
return false; return false;
} }
out_ << mat->columns() << "x" << mat->rows(); out_ << mat->columns() << "x" << mat->rows();
} else if (type->Is<ast::type::Pointer>()) { } else if (auto* ptr = type->As<ast::type::Pointer>()) {
auto* ptr = type->As<ast::type::Pointer>();
// TODO(dsinclair): Storage class? // TODO(dsinclair): Storage class?
if (!EmitType(ptr->type(), "")) { if (!EmitType(ptr->type(), "")) {
return false; return false;
@ -1839,13 +1827,11 @@ bool GeneratorImpl::EmitType(ast::type::Type* type, const std::string& name) {
out_ << "*"; out_ << "*";
} else if (type->Is<ast::type::Sampler>()) { } else if (type->Is<ast::type::Sampler>()) {
out_ << "sampler"; out_ << "sampler";
} else if (type->Is<ast::type::Struct>()) { } else if (auto* str = type->As<ast::type::Struct>()) {
// The struct type emits as just the name. The declaration would be emitted // The struct type emits as just the name. The declaration would be emitted
// as part of emitting the constructed types. // as part of emitting the constructed types.
out_ << type->As<ast::type::Struct>()->name(); out_ << str->name();
} else if (type->Is<ast::type::Texture>()) { } else if (auto* tex = type->As<ast::type::Texture>()) {
auto* tex = type->As<ast::type::Texture>();
if (tex->Is<ast::type::DepthTexture>()) { if (tex->Is<ast::type::DepthTexture>()) {
out_ << "depth"; out_ << "depth";
} else { } else {
@ -1884,8 +1870,7 @@ bool GeneratorImpl::EmitType(ast::type::Type* type, const std::string& name) {
out_ << "<"; out_ << "<";
if (tex->Is<ast::type::DepthTexture>()) { if (tex->Is<ast::type::DepthTexture>()) {
out_ << "float, access::sample"; out_ << "float, access::sample";
} else if (tex->Is<ast::type::StorageTexture>()) { } else if (auto* storage = tex->As<ast::type::StorageTexture>()) {
auto* storage = tex->As<ast::type::StorageTexture>();
if (!EmitType(storage->type(), "")) { if (!EmitType(storage->type(), "")) {
return false; return false;
} }
@ -1898,13 +1883,13 @@ bool GeneratorImpl::EmitType(ast::type::Type* type, const std::string& name) {
error_ = "Invalid access control for storage texture"; error_ = "Invalid access control for storage texture";
return false; return false;
} }
} else if (tex->Is<ast::type::MultisampledTexture>()) { } else if (auto* ms = tex->As<ast::type::MultisampledTexture>()) {
if (!EmitType(tex->As<ast::type::MultisampledTexture>()->type(), "")) { if (!EmitType(ms->type(), "")) {
return false; return false;
} }
out_ << ", access::sample"; out_ << ", access::sample";
} else if (tex->Is<ast::type::SampledTexture>()) { } else if (auto* sampled = tex->As<ast::type::SampledTexture>()) {
if (!EmitType(tex->As<ast::type::SampledTexture>()->type(), "")) { if (!EmitType(sampled->type(), "")) {
return false; return false;
} }
out_ << ", access::sample"; out_ << ", access::sample";
@ -1916,8 +1901,7 @@ bool GeneratorImpl::EmitType(ast::type::Type* type, const std::string& name) {
} else if (type->Is<ast::type::U32>()) { } else if (type->Is<ast::type::U32>()) {
out_ << "uint"; out_ << "uint";
} else if (type->Is<ast::type::Vector>()) { } else if (auto* vec = type->As<ast::type::Vector>()) {
auto* vec = type->As<ast::type::Vector>();
if (!EmitType(vec->type(), "")) { if (!EmitType(vec->type(), "")) {
return false; return false;
} }
@ -2044,8 +2028,8 @@ bool GeneratorImpl::EmitVariable(ast::Variable* var, bool skip_constructor) {
bool GeneratorImpl::EmitProgramConstVariable(const ast::Variable* var) { bool GeneratorImpl::EmitProgramConstVariable(const ast::Variable* var) {
make_indent(); make_indent();
if (var->Is<ast::DecoratedVariable>() && auto* decorated = var->As<ast::DecoratedVariable>();
!var->As<ast::DecoratedVariable>()->HasConstantIdDecoration()) { if (decorated != nullptr && !decorated->HasConstantIdDecoration()) {
error_ = "Decorated const values not valid"; error_ = "Decorated const values not valid";
return false; return false;
} }
@ -2062,10 +2046,8 @@ bool GeneratorImpl::EmitProgramConstVariable(const ast::Variable* var) {
out_ << " " << var->name(); out_ << " " << var->name();
} }
if (var->Is<ast::DecoratedVariable>() && if (decorated != nullptr && decorated->HasConstantIdDecoration()) {
var->As<ast::DecoratedVariable>()->HasConstantIdDecoration()) { out_ << " [[function_constant(" << decorated->constant_id() << ")]]";
out_ << " [[function_constant("
<< var->As<ast::DecoratedVariable>()->constant_id() << ")]]";
} else if (var->constructor() != nullptr) { } else if (var->constructor() != nullptr) {
out_ << " = "; out_ << " = ";
if (!EmitExpression(var->constructor())) { if (!EmitExpression(var->constructor())) {

View File

@ -26,14 +26,12 @@ namespace writer {
namespace { namespace {
ast::TypeConstructorExpression* AsVectorConstructor(ast::Expression* expr) { ast::TypeConstructorExpression* AsVectorConstructor(ast::Expression* expr) {
auto* type_constructor = expr->As<ast::TypeConstructorExpression>(); if (auto* constructor = expr->As<ast::TypeConstructorExpression>()) {
if (type_constructor == nullptr) { if (constructor->type()->Is<ast::type::Vector>()) {
return nullptr; return constructor;
}
} }
if (!type_constructor->type()->Is<ast::type::Vector>()) { return nullptr;
return nullptr;
}
return type_constructor;
} }
} // namespace } // namespace
@ -44,8 +42,7 @@ bool PackCoordAndArrayIndex(
std::function<bool(ast::TypeConstructorExpression*)> callback) { std::function<bool(ast::TypeConstructorExpression*)> callback) {
uint32_t packed_size; uint32_t packed_size;
ast::type::Type* packed_el_ty; // Currenly must be f32. ast::type::Type* packed_el_ty; // Currenly must be f32.
if (coords->result_type()->Is<ast::type::Vector>()) { if (auto* vec = coords->result_type()->As<ast::type::Vector>()) {
auto* vec = coords->result_type()->As<ast::type::Vector>();
packed_size = vec->size() + 1; packed_size = vec->size() + 1;
packed_el_ty = vec->type(); packed_el_ty = vec->type();
} else { } else {

View File

@ -152,11 +152,10 @@ uint32_t IndexFromName(char name) {
/// @param type the given type, which must not be null /// @param type the given type, which must not be null
/// @returns the nested matrix type, or nullptr if none /// @returns the nested matrix type, or nullptr if none
ast::type::Matrix* GetNestedMatrixType(ast::type::Type* type) { ast::type::Matrix* GetNestedMatrixType(ast::type::Type* type) {
while (type->Is<ast::type::Array>()) { while (auto* arr = type->As<ast::type::Array>()) {
type = type->As<ast::type::Array>()->type(); type = arr->type();
} }
return type->Is<ast::type::Matrix>() ? type->As<ast::type::Matrix>() return type->As<ast::type::Matrix>();
: nullptr;
} }
uint32_t intrinsic_to_glsl_method(ast::type::Type* type, uint32_t intrinsic_to_glsl_method(ast::type::Type* type,
@ -721,10 +720,10 @@ bool Builder::GenerateGlobalVariable(ast::Variable* var) {
Operand::Int(ConvertStorageClass(sc))}; Operand::Int(ConvertStorageClass(sc))};
if (var->has_constructor()) { if (var->has_constructor()) {
ops.push_back(Operand::Int(init_id)); ops.push_back(Operand::Int(init_id));
} else if (type->Is<ast::type::Texture>()) { } else if (auto* tex = type->As<ast::type::Texture>()) {
// Decorate storage texture variables with NonRead/Writeable if needed. // Decorate storage texture variables with NonRead/Writeable if needed.
if (type->Is<ast::type::StorageTexture>()) { if (auto* storage = tex->As<ast::type::StorageTexture>()) {
switch (type->As<ast::type::StorageTexture>()->access()) { switch (storage->access()) {
case ast::AccessControl::kWriteOnly: case ast::AccessControl::kWriteOnly:
push_annot( push_annot(
spv::Op::OpDecorate, spv::Op::OpDecorate,
@ -747,8 +746,8 @@ bool Builder::GenerateGlobalVariable(ast::Variable* var) {
// one // one
// 2- If we don't have a constructor and we're an Output or Private variable // 2- If we don't have a constructor and we're an Output or Private variable
// then WGSL requires an initializer. // then WGSL requires an initializer.
if (var->Is<ast::DecoratedVariable>() && auto* decorated = var->As<ast::DecoratedVariable>();
var->As<ast::DecoratedVariable>()->HasConstantIdDecoration()) { if (decorated != nullptr && decorated->HasConstantIdDecoration()) {
if (type->Is<ast::type::F32>()) { if (type->Is<ast::type::F32>()) {
ast::FloatLiteral l(type, 0.0f); ast::FloatLiteral l(type, 0.0f);
init_id = GenerateLiteralIfNeeded(var, &l); init_id = GenerateLiteralIfNeeded(var, &l);
@ -1197,30 +1196,25 @@ bool Builder::is_constructor_const(ast::Expression* expr, bool is_global_init) {
return false; return false;
} }
if (result_type->Is<ast::type::Vector>() && auto* sc = e->As<ast::ScalarConstructorExpression>();
!e->Is<ast::ScalarConstructorExpression>()) { if (result_type->Is<ast::type::Vector>() && sc == nullptr) {
return false; return false;
} }
// This should all be handled by |is_constructor_const| call above // This should all be handled by |is_constructor_const| call above
if (!e->Is<ast::ScalarConstructorExpression>()) { if (sc == nullptr) {
continue; continue;
} }
auto* sc = e->As<ast::ScalarConstructorExpression>();
ast::type::Type* subtype = result_type->UnwrapAll(); ast::type::Type* subtype = result_type->UnwrapAll();
if (subtype->Is<ast::type::Vector>()) { if (auto* vec = subtype->As<ast::type::Vector>()) {
subtype = subtype->As<ast::type::Vector>()->type()->UnwrapAll(); subtype = vec->type()->UnwrapAll();
} else if (subtype->Is<ast::type::Matrix>()) { } else if (auto* mat = subtype->As<ast::type::Matrix>()) {
subtype = subtype->As<ast::type::Matrix>()->type()->UnwrapAll(); subtype = mat->type()->UnwrapAll();
} else if (subtype->Is<ast::type::Array>()) { } else if (auto* arr = subtype->As<ast::type::Array>()) {
subtype = subtype->As<ast::type::Array>()->type()->UnwrapAll(); subtype = arr->type()->UnwrapAll();
} else if (subtype->Is<ast::type::Struct>()) { } else if (auto* str = subtype->As<ast::type::Struct>()) {
subtype = subtype->As<ast::type::Struct>() subtype = str->impl()->members()[i]->type()->UnwrapAll();
->impl()
->members()[i]
->type()
->UnwrapAll();
} }
if (subtype != sc->result_type()->UnwrapAll()) { if (subtype != sc->result_type()->UnwrapAll()) {
return false; return false;
@ -1251,15 +1245,17 @@ uint32_t Builder::GenerateTypeConstructorExpression(
bool can_cast_or_copy = result_type->is_scalar(); bool can_cast_or_copy = result_type->is_scalar();
if (result_type->Is<ast::type::Vector>() && if (auto* res_vec = result_type->As<ast::type::Vector>()) {
result_type->As<ast::type::Vector>()->type()->is_scalar()) { if (res_vec->type()->is_scalar()) {
auto* value_type = values[0]->result_type()->UnwrapAll(); auto* value_type = values[0]->result_type()->UnwrapAll();
can_cast_or_copy = if (auto* val_vec = value_type->As<ast::type::Vector>()) {
(value_type->Is<ast::type::Vector>() && if (val_vec->type()->is_scalar()) {
value_type->As<ast::type::Vector>()->type()->is_scalar() && can_cast_or_copy = res_vec->size() == val_vec->size();
result_type->As<ast::type::Vector>()->size() == }
value_type->As<ast::type::Vector>()->size()); }
}
} }
if (can_cast_or_copy) { if (can_cast_or_copy) {
return GenerateCastOrCopyOrPassthrough(result_type, values[0]); return GenerateCastOrCopyOrPassthrough(result_type, values[0]);
} }
@ -1272,8 +1268,8 @@ uint32_t Builder::GenerateTypeConstructorExpression(
bool result_is_constant_composite = constructor_is_const; bool result_is_constant_composite = constructor_is_const;
bool result_is_spec_composite = false; bool result_is_spec_composite = false;
if (result_type->Is<ast::type::Vector>()) { if (auto* vec = result_type->As<ast::type::Vector>()) {
result_type = result_type->As<ast::type::Vector>()->type(); result_type = vec->type();
} }
OperandList ops; OperandList ops;
@ -1321,8 +1317,7 @@ uint32_t Builder::GenerateTypeConstructorExpression(
// //
// For cases 1 and 2, if the type is different we also may need to insert // For cases 1 and 2, if the type is different we also may need to insert
// a type cast. // a type cast.
if (value_type->Is<ast::type::Vector>()) { if (auto* vec = value_type->As<ast::type::Vector>()) {
auto* vec = value_type->As<ast::type::Vector>();
auto* vec_type = vec->type(); auto* vec_type = vec->type();
auto value_type_id = GenerateTypeIfNeeded(vec_type); auto value_type_id = GenerateTypeIfNeeded(vec_type);
@ -1488,8 +1483,8 @@ uint32_t Builder::GenerateLiteralIfNeeded(ast::Variable* var,
Operand::Int(var->As<ast::DecoratedVariable>()->constant_id())}); Operand::Int(var->As<ast::DecoratedVariable>()->constant_id())});
} }
if (lit->Is<ast::BoolLiteral>()) { if (auto* l = lit->As<ast::BoolLiteral>()) {
if (lit->As<ast::BoolLiteral>()->IsTrue()) { if (l->IsTrue()) {
push_type(is_spec_constant ? spv::Op::OpSpecConstantTrue push_type(is_spec_constant ? spv::Op::OpSpecConstantTrue
: spv::Op::OpConstantTrue, : spv::Op::OpConstantTrue,
{Operand::Int(type_id), result}); {Operand::Int(type_id), result});
@ -1498,18 +1493,15 @@ uint32_t Builder::GenerateLiteralIfNeeded(ast::Variable* var,
: spv::Op::OpConstantFalse, : spv::Op::OpConstantFalse,
{Operand::Int(type_id), result}); {Operand::Int(type_id), result});
} }
} else if (lit->Is<ast::SintLiteral>()) { } else if (auto* sl = lit->As<ast::SintLiteral>()) {
push_type(is_spec_constant ? spv::Op::OpSpecConstant : spv::Op::OpConstant, push_type(is_spec_constant ? spv::Op::OpSpecConstant : spv::Op::OpConstant,
{Operand::Int(type_id), result, {Operand::Int(type_id), result, Operand::Int(sl->value())});
Operand::Int(lit->As<ast::SintLiteral>()->value())}); } else if (auto* ul = lit->As<ast::UintLiteral>()) {
} else if (lit->Is<ast::UintLiteral>()) {
push_type(is_spec_constant ? spv::Op::OpSpecConstant : spv::Op::OpConstant, push_type(is_spec_constant ? spv::Op::OpSpecConstant : spv::Op::OpConstant,
{Operand::Int(type_id), result, {Operand::Int(type_id), result, Operand::Int(ul->value())});
Operand::Int(lit->As<ast::UintLiteral>()->value())}); } else if (auto* fl = lit->As<ast::FloatLiteral>()) {
} else if (lit->Is<ast::FloatLiteral>()) {
push_type(is_spec_constant ? spv::Op::OpSpecConstant : spv::Op::OpConstant, push_type(is_spec_constant ? spv::Op::OpSpecConstant : spv::Op::OpConstant,
{Operand::Int(type_id), result, {Operand::Int(type_id), result, Operand::Float(fl->value())});
Operand::Float(lit->As<ast::FloatLiteral>()->value())});
} else if (lit->Is<ast::NullLiteral>()) { } else if (lit->Is<ast::NullLiteral>()) {
push_type(spv::Op::OpConstantNull, {Operand::Int(type_id), result}); push_type(spv::Op::OpConstantNull, {Operand::Int(type_id), result});
} else { } else {
@ -2413,8 +2405,8 @@ uint32_t Builder::GenerateTypeIfNeeded(ast::type::Type* type) {
} }
// The alias is a wrapper around the subtype, so emit the subtype // The alias is a wrapper around the subtype, so emit the subtype
if (type->Is<ast::type::Alias>()) { if (auto* alias = type->As<ast::type::Alias>()) {
return GenerateTypeIfNeeded(type->As<ast::type::Alias>()->type()); return GenerateTypeIfNeeded(alias->type());
} }
auto val = type_name_to_id_.find(type->type_name()); auto val = type_name_to_id_.find(type->type_name());
@ -2425,8 +2417,7 @@ uint32_t Builder::GenerateTypeIfNeeded(ast::type::Type* type) {
auto result = result_op(); auto result = result_op();
auto id = result.to_i(); auto id = result.to_i();
if (type->Is<ast::type::AccessControl>()) { if (auto* ac = type->As<ast::type::AccessControl>()) {
auto* ac = type->As<ast::type::AccessControl>();
auto* subtype = ac->type()->UnwrapIfNeeded(); auto* subtype = ac->type()->UnwrapIfNeeded();
if (!subtype->Is<ast::type::Struct>()) { if (!subtype->Is<ast::type::Struct>()) {
error_ = "Access control attached to non-struct type."; error_ = "Access control attached to non-struct type.";
@ -2436,8 +2427,8 @@ uint32_t Builder::GenerateTypeIfNeeded(ast::type::Type* type) {
ac->access_control(), result)) { ac->access_control(), result)) {
return 0; return 0;
} }
} else if (type->Is<ast::type::Array>()) { } else if (auto* arr = type->As<ast::type::Array>()) {
if (!GenerateArray(type->As<ast::type::Array>(), result)) { if (!GenerateArrayType(arr, result)) {
return 0; return 0;
} }
} else if (type->Is<ast::type::Bool>()) { } else if (type->Is<ast::type::Bool>()) {
@ -2446,29 +2437,28 @@ uint32_t Builder::GenerateTypeIfNeeded(ast::type::Type* type) {
push_type(spv::Op::OpTypeFloat, {result, Operand::Int(32)}); push_type(spv::Op::OpTypeFloat, {result, Operand::Int(32)});
} else if (type->Is<ast::type::I32>()) { } else if (type->Is<ast::type::I32>()) {
push_type(spv::Op::OpTypeInt, {result, Operand::Int(32), Operand::Int(1)}); push_type(spv::Op::OpTypeInt, {result, Operand::Int(32), Operand::Int(1)});
} else if (type->Is<ast::type::Matrix>()) { } else if (auto* mat = type->As<ast::type::Matrix>()) {
if (!GenerateMatrixType(type->As<ast::type::Matrix>(), result)) { if (!GenerateMatrixType(mat, result)) {
return 0; return 0;
} }
} else if (type->Is<ast::type::Pointer>()) { } else if (auto* ptr = type->As<ast::type::Pointer>()) {
if (!GeneratePointerType(type->As<ast::type::Pointer>(), result)) { if (!GeneratePointerType(ptr, result)) {
return 0; return 0;
} }
} else if (type->Is<ast::type::Struct>()) { } else if (auto* str = type->As<ast::type::Struct>()) {
if (!GenerateStructType(type->As<ast::type::Struct>(), if (!GenerateStructType(str, ast::AccessControl::kReadWrite, result)) {
ast::AccessControl::kReadWrite, result)) {
return 0; return 0;
} }
} else if (type->Is<ast::type::U32>()) { } else if (type->Is<ast::type::U32>()) {
push_type(spv::Op::OpTypeInt, {result, Operand::Int(32), Operand::Int(0)}); push_type(spv::Op::OpTypeInt, {result, Operand::Int(32), Operand::Int(0)});
} else if (type->Is<ast::type::Vector>()) { } else if (auto* vec = type->As<ast::type::Vector>()) {
if (!GenerateVectorType(type->As<ast::type::Vector>(), result)) { if (!GenerateVectorType(vec, result)) {
return 0; return 0;
} }
} else if (type->Is<ast::type::Void>()) { } else if (type->Is<ast::type::Void>()) {
push_type(spv::Op::OpTypeVoid, {result}); push_type(spv::Op::OpTypeVoid, {result});
} else if (type->Is<ast::type::Texture>()) { } else if (auto* tex = type->As<ast::type::Texture>()) {
if (!GenerateTextureType(type->As<ast::type::Texture>(), result)) { if (!GenerateTextureType(tex, result)) {
return 0; return 0;
} }
} else if (type->Is<ast::type::Sampler>()) { } else if (type->Is<ast::type::Sampler>()) {
@ -2546,20 +2536,16 @@ bool Builder::GenerateTextureType(ast::type::Texture* texture,
if (texture->Is<ast::type::DepthTexture>()) { if (texture->Is<ast::type::DepthTexture>()) {
ast::type::F32 f32; ast::type::F32 f32;
type_id = GenerateTypeIfNeeded(&f32); type_id = GenerateTypeIfNeeded(&f32);
} else if (texture->Is<ast::type::SampledTexture>()) { } else if (auto* s = texture->As<ast::type::SampledTexture>()) {
type_id = type_id = GenerateTypeIfNeeded(s->type());
GenerateTypeIfNeeded(texture->As<ast::type::SampledTexture>()->type()); } else if (auto* ms = texture->As<ast::type::MultisampledTexture>()) {
} else if (texture->Is<ast::type::MultisampledTexture>()) { type_id = GenerateTypeIfNeeded(ms->type());
type_id = GenerateTypeIfNeeded( } else if (auto* st = texture->As<ast::type::StorageTexture>()) {
texture->As<ast::type::MultisampledTexture>()->type()); if (st->access() == ast::AccessControl::kWriteOnly) {
} else if (texture->Is<ast::type::StorageTexture>()) {
if (texture->As<ast::type::StorageTexture>()->access() ==
ast::AccessControl::kWriteOnly) {
ast::type::Void void_type; ast::type::Void void_type;
type_id = GenerateTypeIfNeeded(&void_type); type_id = GenerateTypeIfNeeded(&void_type);
} else { } else {
type_id = GenerateTypeIfNeeded( type_id = GenerateTypeIfNeeded(st->type());
texture->As<ast::type::StorageTexture>()->type());
} }
} }
if (type_id == 0u) { if (type_id == 0u) {
@ -2567,9 +2553,8 @@ bool Builder::GenerateTextureType(ast::type::Texture* texture,
} }
uint32_t format_literal = SpvImageFormat_::SpvImageFormatUnknown; uint32_t format_literal = SpvImageFormat_::SpvImageFormatUnknown;
if (texture->Is<ast::type::StorageTexture>()) { if (auto* t = texture->As<ast::type::StorageTexture>()) {
format_literal = convert_image_format_to_spv( format_literal = convert_image_format_to_spv(t->image_format());
texture->As<ast::type::StorageTexture>()->image_format());
} }
push_type(spv::Op::OpTypeImage, push_type(spv::Op::OpTypeImage,
@ -2581,7 +2566,7 @@ bool Builder::GenerateTextureType(ast::type::Texture* texture,
return true; return true;
} }
bool Builder::GenerateArray(ast::type::Array* ary, const Operand& result) { bool Builder::GenerateArrayType(ast::type::Array* ary, const Operand& result) {
auto elem_type = GenerateTypeIfNeeded(ary->type()); auto elem_type = GenerateTypeIfNeeded(ary->type());
if (elem_type == 0) { if (elem_type == 0) {
return false; return false;

View File

@ -431,7 +431,7 @@ class Builder {
/// @param ary the array to generate /// @param ary the array to generate
/// @param result the result operand /// @param result the result operand
/// @returns true if the array was successfully generated /// @returns true if the array was successfully generated
bool GenerateArray(ast::type::Array* ary, const Operand& result); bool GenerateArrayType(ast::type::Array* ary, const Operand& result);
/// Generates a matrix type declaration /// Generates a matrix type declaration
/// @param mat the matrix to generate /// @param mat the matrix to generate
/// @param result the result operand /// @param result the result operand

View File

@ -173,15 +173,14 @@ bool GeneratorImpl::GenerateEntryPoint(const ast::Module& module,
bool GeneratorImpl::EmitConstructedType(const ast::type::Type* ty) { bool GeneratorImpl::EmitConstructedType(const ast::type::Type* ty) {
make_indent(); make_indent();
if (ty->Is<ast::type::Alias>()) { if (auto* alias = ty->As<ast::type::Alias>()) {
auto* alias = ty->As<ast::type::Alias>();
out_ << "type " << alias->name() << " = "; out_ << "type " << alias->name() << " = ";
if (!EmitType(alias->type())) { if (!EmitType(alias->type())) {
return false; return false;
} }
out_ << ";" << std::endl; out_ << ";" << std::endl;
} else if (ty->Is<ast::type::Struct>()) { } else if (auto* str = ty->As<ast::type::Struct>()) {
if (!EmitStructType(ty->As<ast::type::Struct>())) { if (!EmitStructType(str)) {
return false; return false;
} }
} else { } else {
@ -321,14 +320,14 @@ bool GeneratorImpl::EmitScalarConstructor(
} }
bool GeneratorImpl::EmitLiteral(ast::Literal* lit) { bool GeneratorImpl::EmitLiteral(ast::Literal* lit) {
if (lit->Is<ast::BoolLiteral>()) { if (auto* bl = lit->As<ast::BoolLiteral>()) {
out_ << (lit->As<ast::BoolLiteral>()->IsTrue() ? "true" : "false"); out_ << (bl->IsTrue() ? "true" : "false");
} else if (lit->Is<ast::FloatLiteral>()) { } else if (auto* fl = lit->As<ast::FloatLiteral>()) {
out_ << FloatToString(lit->As<ast::FloatLiteral>()->value()); out_ << FloatToString(fl->value());
} else if (lit->Is<ast::SintLiteral>()) { } else if (auto* sl = lit->As<ast::SintLiteral>()) {
out_ << lit->As<ast::SintLiteral>()->value(); out_ << sl->value();
} else if (lit->Is<ast::UintLiteral>()) { } else if (auto* ul = lit->As<ast::UintLiteral>()) {
out_ << lit->As<ast::UintLiteral>()->value() << "u"; out_ << ul->value() << "u";
} else { } else {
error_ = "unknown literal type"; error_ = "unknown literal type";
return false; return false;
@ -399,9 +398,7 @@ bool GeneratorImpl::EmitImageFormat(const ast::type::ImageFormat fmt) {
} }
bool GeneratorImpl::EmitType(ast::type::Type* type) { bool GeneratorImpl::EmitType(ast::type::Type* type) {
if (type->Is<ast::type::AccessControl>()) { if (auto* ac = type->As<ast::type::AccessControl>()) {
auto* ac = type->As<ast::type::AccessControl>();
out_ << "[[access("; out_ << "[[access(";
if (ac->IsReadOnly()) { if (ac->IsReadOnly()) {
out_ << "read"; out_ << "read";
@ -415,11 +412,9 @@ bool GeneratorImpl::EmitType(ast::type::Type* type) {
if (!EmitType(ac->type())) { if (!EmitType(ac->type())) {
return false; return false;
} }
} else if (type->Is<ast::type::Alias>()) { } else if (auto* alias = type->As<ast::type::Alias>()) {
out_ << type->As<ast::type::Alias>()->name(); out_ << alias->name();
} else if (type->Is<ast::type::Array>()) { } else if (auto* ary = type->As<ast::type::Array>()) {
auto* ary = type->As<ast::type::Array>();
for (auto* deco : ary->decorations()) { for (auto* deco : ary->decorations()) {
if (auto* stride = deco->As<ast::StrideDecoration>()) { if (auto* stride = deco->As<ast::StrideDecoration>()) {
out_ << "[[stride(" << stride->stride() << ")]] "; out_ << "[[stride(" << stride->stride() << ")]] ";
@ -441,34 +436,29 @@ bool GeneratorImpl::EmitType(ast::type::Type* type) {
out_ << "f32"; out_ << "f32";
} else if (type->Is<ast::type::I32>()) { } else if (type->Is<ast::type::I32>()) {
out_ << "i32"; out_ << "i32";
} else if (type->Is<ast::type::Matrix>()) { } else if (auto* mat = type->As<ast::type::Matrix>()) {
auto* mat = type->As<ast::type::Matrix>();
out_ << "mat" << mat->columns() << "x" << mat->rows() << "<"; out_ << "mat" << mat->columns() << "x" << mat->rows() << "<";
if (!EmitType(mat->type())) { if (!EmitType(mat->type())) {
return false; return false;
} }
out_ << ">"; out_ << ">";
} else if (type->Is<ast::type::Pointer>()) { } else if (auto* ptr = type->As<ast::type::Pointer>()) {
auto* ptr = type->As<ast::type::Pointer>();
out_ << "ptr<" << ptr->storage_class() << ", "; out_ << "ptr<" << ptr->storage_class() << ", ";
if (!EmitType(ptr->type())) { if (!EmitType(ptr->type())) {
return false; return false;
} }
out_ << ">"; out_ << ">";
} else if (type->Is<ast::type::Sampler>()) { } else if (auto* sampler = type->As<ast::type::Sampler>()) {
auto* sampler = type->As<ast::type::Sampler>();
out_ << "sampler"; out_ << "sampler";
if (sampler->IsComparison()) { if (sampler->IsComparison()) {
out_ << "_comparison"; out_ << "_comparison";
} }
} else if (type->Is<ast::type::Struct>()) { } else if (auto* str = type->As<ast::type::Struct>()) {
// The struct, as a type, is just the name. We should have already emitted // The struct, as a type, is just the name. We should have already emitted
// the declaration through a call to |EmitStructType| earlier. // the declaration through a call to |EmitStructType| earlier.
out_ << type->As<ast::type::Struct>()->name(); out_ << str->name();
} else if (type->Is<ast::type::Texture>()) { } else if (auto* texture = type->As<ast::type::Texture>()) {
auto* texture = type->As<ast::type::Texture>();
out_ << "texture_"; out_ << "texture_";
if (texture->Is<ast::type::DepthTexture>()) { if (texture->Is<ast::type::DepthTexture>()) {
out_ << "depth_"; out_ << "depth_";
@ -476,10 +466,8 @@ bool GeneratorImpl::EmitType(ast::type::Type* type) {
/* nothing to emit */ /* nothing to emit */
} else if (texture->Is<ast::type::MultisampledTexture>()) { } else if (texture->Is<ast::type::MultisampledTexture>()) {
out_ << "multisampled_"; out_ << "multisampled_";
} else if (texture->Is<ast::type::StorageTexture>()) { } else if (auto* storage = texture->As<ast::type::StorageTexture>()) {
out_ << "storage_"; out_ << "storage_";
auto* storage = texture->As<ast::type::StorageTexture>();
if (storage->access() == ast::AccessControl::kReadOnly) { if (storage->access() == ast::AccessControl::kReadOnly) {
out_ << "ro_"; out_ << "ro_";
} else if (storage->access() == ast::AccessControl::kWriteOnly) { } else if (storage->access() == ast::AccessControl::kWriteOnly) {
@ -520,25 +508,19 @@ bool GeneratorImpl::EmitType(ast::type::Type* type) {
return false; return false;
} }
if (texture->Is<ast::type::SampledTexture>()) { if (auto* sampled = texture->As<ast::type::SampledTexture>()) {
auto* sampled = texture->As<ast::type::SampledTexture>();
out_ << "<"; out_ << "<";
if (!EmitType(sampled->type())) { if (!EmitType(sampled->type())) {
return false; return false;
} }
out_ << ">"; out_ << ">";
} else if (texture->Is<ast::type::MultisampledTexture>()) { } else if (auto* ms = texture->As<ast::type::MultisampledTexture>()) {
auto* sampled = texture->As<ast::type::MultisampledTexture>();
out_ << "<"; out_ << "<";
if (!EmitType(sampled->type())) { if (!EmitType(ms->type())) {
return false; return false;
} }
out_ << ">"; out_ << ">";
} else if (texture->Is<ast::type::StorageTexture>()) { } else if (auto* storage = texture->As<ast::type::StorageTexture>()) {
auto* storage = texture->As<ast::type::StorageTexture>();
out_ << "<"; out_ << "<";
if (!EmitImageFormat(storage->image_format())) { if (!EmitImageFormat(storage->image_format())) {
return false; return false;
@ -548,8 +530,7 @@ bool GeneratorImpl::EmitType(ast::type::Type* type) {
} else if (type->Is<ast::type::U32>()) { } else if (type->Is<ast::type::U32>()) {
out_ << "u32"; out_ << "u32";
} else if (type->Is<ast::type::Vector>()) { } else if (auto* vec = type->As<ast::type::Vector>()) {
auto* vec = type->As<ast::type::Vector>();
out_ << "vec" << vec->size() << "<"; out_ << "vec" << vec->size() << "<";
if (!EmitType(vec->type())) { if (!EmitType(vec->type())) {
return false; return false;
@ -580,10 +561,9 @@ bool GeneratorImpl::EmitStructType(const ast::type::Struct* str) {
make_indent(); make_indent();
// TODO(dsinclair): Split this out when we have more then one // TODO(dsinclair): Split this out when we have more then one
assert(deco->Is<ast::StructMemberOffsetDecoration>()); auto* offset = deco->As<ast::StructMemberOffsetDecoration>();
out_ << "[[offset(" assert(offset != nullptr);
<< deco->As<ast::StructMemberOffsetDecoration>()->offset() << ")]]" out_ << "[[offset(" << offset->offset() << ")]]" << std::endl;
<< std::endl;
} }
make_indent(); make_indent();
out_ << mem->name() << " : "; out_ << mem->name() << " : ";
@ -651,8 +631,8 @@ bool GeneratorImpl::EmitVariableDecorations(ast::DecoratedVariable* var) {
out_ << "location(" << location->value() << ")"; out_ << "location(" << location->value() << ")";
} else if (auto* builtin = deco->As<ast::BuiltinDecoration>()) { } else if (auto* builtin = deco->As<ast::BuiltinDecoration>()) {
out_ << "builtin(" << builtin->value() << ")"; out_ << "builtin(" << builtin->value() << ")";
} else if (auto* cid = deco->As<ast::ConstantIdDecoration>()) { } else if (auto* constant = deco->As<ast::ConstantIdDecoration>()) {
out_ << "constant_id(" << cid->value() << ")"; out_ << "constant_id(" << constant->value() << ")";
} else { } else {
error_ = "unknown variable decoration"; error_ = "unknown variable decoration";
return false; return false;