Convert `@location` to store expression internally.

This CL updates the internal storage for a `@location` attribute
to store the `Expression` instead of a raw `uint32_t`. The current
parser is updated to generate an `IntLiteralExpression` so we still
parse as a `uint32_t` at the moment.

Bug: tint:1633
Change-Id: I2b9684754a657b39554160c81727cf1541bee96c
Reviewed-on: https://dawn-review.googlesource.com/c/dawn/+/101461
Kokoro: Kokoro <noreply+kokoro@google.com>
Reviewed-by: Ben Clayton <bclayton@google.com>
Commit-Queue: Dan Sinclair <dsinclair@chromium.org>
This commit is contained in:
dan sinclair 2022-09-07 22:25:24 +00:00 committed by Dawn LUCI CQ
parent 145337f309
commit f9eeed6106
41 changed files with 387 additions and 249 deletions

View File

@ -22,7 +22,10 @@ TINT_INSTANTIATE_TYPEINFO(tint::ast::LocationAttribute);
namespace tint::ast { namespace tint::ast {
LocationAttribute::LocationAttribute(ProgramID pid, NodeID nid, const Source& src, uint32_t val) LocationAttribute::LocationAttribute(ProgramID pid,
NodeID nid,
const Source& src,
const ast::Expression* val)
: Base(pid, nid, src), value(val) {} : Base(pid, nid, src), value(val) {}
LocationAttribute::~LocationAttribute() = default; LocationAttribute::~LocationAttribute() = default;
@ -34,7 +37,8 @@ std::string LocationAttribute::Name() const {
const LocationAttribute* LocationAttribute::Clone(CloneContext* ctx) const { const LocationAttribute* LocationAttribute::Clone(CloneContext* ctx) const {
// Clone arguments outside of create() call to have deterministic ordering // Clone arguments outside of create() call to have deterministic ordering
auto src = ctx->Clone(source); auto src = ctx->Clone(source);
return ctx->dst->create<LocationAttribute>(src, value); auto value_ = ctx->Clone(value);
return ctx->dst->create<LocationAttribute>(src, value_);
} }
} // namespace tint::ast } // namespace tint::ast

View File

@ -18,6 +18,7 @@
#include <string> #include <string>
#include "src/tint/ast/attribute.h" #include "src/tint/ast/attribute.h"
#include "src/tint/ast/expression.h"
namespace tint::ast { namespace tint::ast {
@ -28,8 +29,8 @@ class LocationAttribute final : public Castable<LocationAttribute, Attribute> {
/// @param pid the identifier of the program that owns this node /// @param pid the identifier of the program that owns this node
/// @param nid the unique node identifier /// @param nid the unique node identifier
/// @param src the source of this node /// @param src the source of this node
/// @param value the location value /// @param value the location value expression
LocationAttribute(ProgramID pid, NodeID nid, const Source& src, uint32_t value); LocationAttribute(ProgramID pid, NodeID nid, const Source& src, const ast::Expression* value);
~LocationAttribute() override; ~LocationAttribute() override;
/// @returns the WGSL name for the attribute /// @returns the WGSL name for the attribute
@ -42,7 +43,7 @@ class LocationAttribute final : public Castable<LocationAttribute, Attribute> {
const LocationAttribute* Clone(CloneContext* ctx) const override; const LocationAttribute* Clone(CloneContext* ctx) const override;
/// The location value /// The location value
const uint32_t value; const ast::Expression* const value;
}; };
} // namespace tint::ast } // namespace tint::ast

View File

@ -17,11 +17,12 @@
namespace tint::ast { namespace tint::ast {
namespace { namespace {
using namespace tint::number_suffixes; // NOLINT
using LocationAttributeTest = TestHelper; using LocationAttributeTest = TestHelper;
TEST_F(LocationAttributeTest, Creation) { TEST_F(LocationAttributeTest, Creation) {
auto* d = create<LocationAttribute>(2u); auto* d = Location(2_a);
EXPECT_EQ(2u, d->value); EXPECT_TRUE(d->value->Is<IntLiteralExpression>());
} }
} // namespace } // namespace

View File

@ -92,7 +92,7 @@ TEST_F(VariableTest, Assert_DifferentProgramID_Constructor) {
} }
TEST_F(VariableTest, WithAttributes) { TEST_F(VariableTest, WithAttributes) {
auto* var = Var("my_var", ty.i32(), StorageClass::kFunction, Location(1u), auto* var = Var("my_var", ty.i32(), StorageClass::kFunction, Location(1_u),
Builtin(BuiltinValue::kPosition), Id(1200_u)); Builtin(BuiltinValue::kPosition), Id(1200_u));
auto& attributes = var->attributes; auto& attributes = var->attributes;
@ -102,7 +102,8 @@ TEST_F(VariableTest, WithAttributes) {
auto* location = ast::GetAttribute<ast::LocationAttribute>(attributes); auto* location = ast::GetAttribute<ast::LocationAttribute>(attributes);
ASSERT_NE(nullptr, location); ASSERT_NE(nullptr, location);
EXPECT_EQ(1u, location->value); ASSERT_NE(nullptr, location->value);
EXPECT_TRUE(location->value->Is<ast::IntLiteralExpression>());
} }
TEST_F(VariableTest, HasBindingPoint_BothProvided) { TEST_F(VariableTest, HasBindingPoint_BothProvided) {

View File

@ -172,7 +172,7 @@ EntryPoint Inspector::GetEntryPoint(const tint::ast::Function* func) {
for (auto* param : sem->Parameters()) { for (auto* param : sem->Parameters()) {
AddEntryPointInOutVariables(program_->Symbols().NameFor(param->Declaration()->symbol), AddEntryPointInOutVariables(program_->Symbols().NameFor(param->Declaration()->symbol),
param->Type(), param->Declaration()->attributes, param->Type(), param->Declaration()->attributes,
entry_point.input_variables); param->Location(), entry_point.input_variables);
entry_point.input_position_used |= ContainsBuiltin( entry_point.input_position_used |= ContainsBuiltin(
ast::BuiltinValue::kPosition, param->Type(), param->Declaration()->attributes); ast::BuiltinValue::kPosition, param->Type(), param->Declaration()->attributes);
@ -188,7 +188,7 @@ EntryPoint Inspector::GetEntryPoint(const tint::ast::Function* func) {
if (!sem->ReturnType()->Is<sem::Void>()) { if (!sem->ReturnType()->Is<sem::Void>()) {
AddEntryPointInOutVariables("<retval>", sem->ReturnType(), func->return_type_attributes, AddEntryPointInOutVariables("<retval>", sem->ReturnType(), func->return_type_attributes,
entry_point.output_variables); sem->ReturnLocation(), entry_point.output_variables);
entry_point.output_sample_mask_used = ContainsBuiltin( entry_point.output_sample_mask_used = ContainsBuiltin(
ast::BuiltinValue::kSampleMask, sem->ReturnType(), func->return_type_attributes); ast::BuiltinValue::kSampleMask, sem->ReturnType(), func->return_type_attributes);
@ -623,6 +623,7 @@ const ast::Function* Inspector::FindEntryPointByName(const std::string& name) {
void Inspector::AddEntryPointInOutVariables(std::string name, void Inspector::AddEntryPointInOutVariables(std::string name,
const sem::Type* type, const sem::Type* type,
utils::VectorRef<const ast::Attribute*> attributes, utils::VectorRef<const ast::Attribute*> attributes,
std::optional<uint32_t> location,
std::vector<StageVariable>& variables) const { std::vector<StageVariable>& variables) const {
// Skip builtins. // Skip builtins.
if (ast::HasAttribute<ast::BuiltinAttribute>(attributes)) { if (ast::HasAttribute<ast::BuiltinAttribute>(attributes)) {
@ -636,7 +637,7 @@ void Inspector::AddEntryPointInOutVariables(std::string name,
for (auto* member : struct_ty->Members()) { for (auto* member : struct_ty->Members()) {
AddEntryPointInOutVariables( AddEntryPointInOutVariables(
name + "." + program_->Symbols().NameFor(member->Declaration()->symbol), name + "." + program_->Symbols().NameFor(member->Declaration()->symbol),
member->Type(), member->Declaration()->attributes, variables); member->Type(), member->Declaration()->attributes, member->Location(), variables);
} }
return; return;
} }
@ -648,10 +649,9 @@ void Inspector::AddEntryPointInOutVariables(std::string name,
std::tie(stage_variable.component_type, stage_variable.composition_type) = std::tie(stage_variable.component_type, stage_variable.composition_type) =
CalculateComponentAndComposition(type); CalculateComponentAndComposition(type);
auto* location = ast::GetAttribute<ast::LocationAttribute>(attributes); TINT_ASSERT(Inspector, location.has_value());
TINT_ASSERT(Inspector, location != nullptr);
stage_variable.has_location_attribute = true; stage_variable.has_location_attribute = true;
stage_variable.location_attribute = location->value; stage_variable.location_attribute = location.value();
std::tie(stage_variable.interpolation_type, stage_variable.interpolation_sampling) = std::tie(stage_variable.interpolation_type, stage_variable.interpolation_sampling) =
CalculateInterpolationData(type, attributes); CalculateInterpolationData(type, attributes);

View File

@ -172,10 +172,12 @@ class Inspector {
/// @param name the name of the variable being added /// @param name the name of the variable being added
/// @param type the type of the variable /// @param type the type of the variable
/// @param attributes the variable attributes /// @param attributes the variable attributes
/// @param location the location value if provided
/// @param variables the list to add the variables to /// @param variables the list to add the variables to
void AddEntryPointInOutVariables(std::string name, void AddEntryPointInOutVariables(std::string name,
const sem::Type* type, const sem::Type* type,
utils::VectorRef<const ast::Attribute*> attributes, utils::VectorRef<const ast::Attribute*> attributes,
std::optional<uint32_t> location,
std::vector<StageVariable>& variables) const; std::vector<StageVariable>& variables) const;
/// Recursively determine if the type contains builtin. /// Recursively determine if the type contains builtin.

View File

@ -291,7 +291,7 @@ TEST_P(InspectorGetEntryPointComponentAndCompositionTest, Test) {
auto* in_var = Param("in_var", tint_type(), auto* in_var = Param("in_var", tint_type(),
utils::Vector{ utils::Vector{
Location(0u), Location(0_u),
Flat(), Flat(),
}); });
Func("foo", utils::Vector{in_var}, tint_type(), Func("foo", utils::Vector{in_var}, tint_type(),
@ -302,7 +302,7 @@ TEST_P(InspectorGetEntryPointComponentAndCompositionTest, Test) {
Stage(ast::PipelineStage::kFragment), Stage(ast::PipelineStage::kFragment),
}, },
utils::Vector{ utils::Vector{
Location(0u), Location(0_u),
}); });
Inspector& inspector = Build(); Inspector& inspector = Build();
@ -336,17 +336,17 @@ INSTANTIATE_TEST_SUITE_P(InspectorGetEntryPointTest,
TEST_F(InspectorGetEntryPointTest, MultipleInOutVariables) { TEST_F(InspectorGetEntryPointTest, MultipleInOutVariables) {
auto* in_var0 = Param("in_var0", ty.u32(), auto* in_var0 = Param("in_var0", ty.u32(),
utils::Vector{ utils::Vector{
Location(0u), Location(0_u),
Flat(), Flat(),
}); });
auto* in_var1 = Param("in_var1", ty.u32(), auto* in_var1 = Param("in_var1", ty.u32(),
utils::Vector{ utils::Vector{
Location(1u), Location(1_u),
Flat(), Flat(),
}); });
auto* in_var4 = Param("in_var4", ty.u32(), auto* in_var4 = Param("in_var4", ty.u32(),
utils::Vector{ utils::Vector{
Location(4u), Location(4_u),
Flat(), Flat(),
}); });
Func("foo", utils::Vector{in_var0, in_var1, in_var4}, ty.u32(), Func("foo", utils::Vector{in_var0, in_var1, in_var4}, ty.u32(),
@ -357,7 +357,7 @@ TEST_F(InspectorGetEntryPointTest, MultipleInOutVariables) {
Stage(ast::PipelineStage::kFragment), Stage(ast::PipelineStage::kFragment),
}, },
utils::Vector{ utils::Vector{
Location(0u), Location(0_u),
}); });
Inspector& inspector = Build(); Inspector& inspector = Build();
@ -393,7 +393,7 @@ TEST_F(InspectorGetEntryPointTest, MultipleInOutVariables) {
TEST_F(InspectorGetEntryPointTest, MultipleEntryPointsInOutVariables) { TEST_F(InspectorGetEntryPointTest, MultipleEntryPointsInOutVariables) {
auto* in_var_foo = Param("in_var_foo", ty.u32(), auto* in_var_foo = Param("in_var_foo", ty.u32(),
utils::Vector{ utils::Vector{
Location(0u), Location(0_u),
Flat(), Flat(),
}); });
Func("foo", utils::Vector{in_var_foo}, ty.u32(), Func("foo", utils::Vector{in_var_foo}, ty.u32(),
@ -404,12 +404,12 @@ TEST_F(InspectorGetEntryPointTest, MultipleEntryPointsInOutVariables) {
Stage(ast::PipelineStage::kFragment), Stage(ast::PipelineStage::kFragment),
}, },
utils::Vector{ utils::Vector{
Location(0u), Location(0_u),
}); });
auto* in_var_bar = Param("in_var_bar", ty.u32(), auto* in_var_bar = Param("in_var_bar", ty.u32(),
utils::Vector{ utils::Vector{
Location(0u), Location(0_u),
Flat(), Flat(),
}); });
Func("bar", utils::Vector{in_var_bar}, ty.u32(), Func("bar", utils::Vector{in_var_bar}, ty.u32(),
@ -420,7 +420,7 @@ TEST_F(InspectorGetEntryPointTest, MultipleEntryPointsInOutVariables) {
Stage(ast::PipelineStage::kFragment), Stage(ast::PipelineStage::kFragment),
}, },
utils::Vector{ utils::Vector{
Location(1u), Location(1_u),
}); });
Inspector& inspector = Build(); Inspector& inspector = Build();
@ -464,7 +464,7 @@ TEST_F(InspectorGetEntryPointTest, BuiltInsNotStageVariables) {
}); });
auto* in_var1 = Param("in_var1", ty.f32(), auto* in_var1 = Param("in_var1", ty.f32(),
utils::Vector{ utils::Vector{
Location(0u), Location(0_u),
}); });
Func("foo", utils::Vector{in_var0, in_var1}, ty.f32(), Func("foo", utils::Vector{in_var0, in_var1}, ty.f32(),
utils::Vector{ utils::Vector{
@ -596,8 +596,8 @@ TEST_F(InspectorGetEntryPointTest, MixInOutVariablesAndStruct) {
utils::Vector{ utils::Vector{
Param("param_a", ty.Of(struct_a)), Param("param_a", ty.Of(struct_a)),
Param("param_b", ty.Of(struct_b)), Param("param_b", ty.Of(struct_b)),
Param("param_c", ty.f32(), utils::Vector{Location(3u)}), Param("param_c", ty.f32(), utils::Vector{Location(3_u)}),
Param("param_d", ty.f32(), utils::Vector{Location(4u)}), Param("param_d", ty.f32(), utils::Vector{Location(4_u)}),
}, },
ty.Of(struct_a), ty.Of(struct_a),
utils::Vector{ utils::Vector{
@ -1136,7 +1136,7 @@ TEST_F(InspectorGetEntryPointTest, NumWorkgroupsStructReferenced) {
TEST_F(InspectorGetEntryPointTest, ImplicitInterpolate) { TEST_F(InspectorGetEntryPointTest, ImplicitInterpolate) {
Structure("in_struct", utils::Vector{ Structure("in_struct", utils::Vector{
Member("struct_inner", ty.f32(), utils::Vector{Location(0)}), Member("struct_inner", ty.f32(), utils::Vector{Location(0_a)}),
}); });
Func("ep_func", Func("ep_func",
@ -1167,7 +1167,7 @@ TEST_P(InspectorGetEntryPointInterpolateTest, Test) {
"in_struct", "in_struct",
utils::Vector{ utils::Vector{
Member("struct_inner", ty.f32(), Member("struct_inner", ty.f32(),
utils::Vector{Interpolate(params.in_type, params.in_sampling), Location(0)}), utils::Vector{Interpolate(params.in_type, params.in_sampling), Location(0_a)}),
}); });
Func("ep_func", Func("ep_func",

View File

@ -54,7 +54,7 @@ const ast::Struct* InspectorBuilder::MakeInOutStruct(std::string name,
std::tie(member_name, location) = var; std::tie(member_name, location) = var;
members.Push(Member(member_name, ty.u32(), members.Push(Member(member_name, ty.u32(),
utils::Vector{ utils::Vector{
Location(location), Location(AInt(location)),
Flat(), Flat(),
})); }));
} }

View File

@ -2928,17 +2928,19 @@ class ProgramBuilder {
/// Creates an ast::LocationAttribute /// Creates an ast::LocationAttribute
/// @param source the source information /// @param source the source information
/// @param location the location value /// @param location the location value expression
/// @returns the location attribute pointer /// @returns the location attribute pointer
const ast::LocationAttribute* Location(const Source& source, uint32_t location) { template <typename EXPR>
return create<ast::LocationAttribute>(source, location); const ast::LocationAttribute* Location(const Source& source, EXPR&& location) {
return create<ast::LocationAttribute>(source, Expr(std::forward<EXPR>(location)));
} }
/// Creates an ast::LocationAttribute /// Creates an ast::LocationAttribute
/// @param location the location value /// @param location the location value expression
/// @returns the location attribute pointer /// @returns the location attribute pointer
const ast::LocationAttribute* Location(uint32_t location) { template <typename EXPR>
return create<ast::LocationAttribute>(source_, location); const ast::LocationAttribute* Location(EXPR&& location) {
return create<ast::LocationAttribute>(source_, Expr(std::forward<EXPR>(location)));
} }
/// Creates an ast::IdAttribute /// Creates an ast::IdAttribute

View File

@ -1109,7 +1109,9 @@ void FunctionEmitter::IncrementLocation(AttributeList* attributes) {
// Replace this location attribute with a new one with one higher index. // Replace this location attribute with a new one with one higher index.
// The old one doesn't leak because it's kept in the builder's AST node // The old one doesn't leak because it's kept in the builder's AST node
// list. // list.
attr = builder_.Location(loc_attr->source, loc_attr->value + 1); attr = builder_.Location(
loc_attr->source,
AInt(loc_attr->value->As<ast::IntLiteralExpression>()->value + 1));
} }
} }
} }

View File

@ -1723,25 +1723,22 @@ DecorationList ParserImpl::GetMemberPipelineDecorations(const Struct& struct_typ
return result; return result;
} }
const ast::Attribute* ParserImpl::SetLocation(AttributeList* attributes, void ParserImpl::SetLocation(AttributeList* attributes, const ast::Attribute* replacement) {
const ast::Attribute* replacement) {
if (!replacement) { if (!replacement) {
return nullptr; return;
} }
for (auto*& attribute : *attributes) { for (auto*& attribute : *attributes) {
if (attribute->Is<ast::LocationAttribute>()) { if (attribute->Is<ast::LocationAttribute>()) {
// Replace this location attribute with the replacement. // Replace this location attribute with the replacement.
// The old one doesn't leak because it's kept in the builder's AST node // The old one doesn't leak because it's kept in the builder's AST node
// list. // list.
const ast::Attribute* result = nullptr;
result = attribute;
attribute = replacement; attribute = replacement;
return result; // Assume there is only one such decoration. return; // Assume there is only one such decoration.
} }
} }
// The list didn't have a location. Add it. // The list didn't have a location. Add it.
attributes->Push(replacement); attributes->Push(replacement);
return nullptr; return;
} }
bool ParserImpl::ConvertPipelineDecorations(const Type* store_type, bool ParserImpl::ConvertPipelineDecorations(const Type* store_type,
@ -1759,7 +1756,7 @@ bool ParserImpl::ConvertPipelineDecorations(const Type* store_type,
return Fail() << "malformed Location decoration on ID requires one " return Fail() << "malformed Location decoration on ID requires one "
"literal operand"; "literal operand";
} }
SetLocation(attributes, create<ast::LocationAttribute>(Source{}, deco[1])); SetLocation(attributes, builder_.Location(AInt(deco[1])));
if (store_type->IsIntegerScalarOrVector()) { if (store_type->IsIntegerScalarOrVector()) {
// Default to flat interpolation for integral user-defined IO types. // Default to flat interpolation for integral user-defined IO types.
type = ast::InterpolationType::kFlat; type = ast::InterpolationType::kFlat;

View File

@ -280,9 +280,7 @@ class ParserImpl : Reader {
/// Assumes the list contains at most one Location decoration. /// Assumes the list contains at most one Location decoration.
/// @param decos the attribute list to modify /// @param decos the attribute list to modify
/// @param replacement the location decoration to place into the list /// @param replacement the location decoration to place into the list
/// @returns the location decoration that was replaced, if one was replaced, void SetLocation(AttributeList* decos, const ast::Attribute* replacement);
/// or null otherwise.
const ast::Attribute* SetLocation(AttributeList* decos, const ast::Attribute* replacement);
/// Converts a SPIR-V struct member decoration into a number of AST /// Converts a SPIR-V struct member decoration into a number of AST
/// decorations. If the decoration is recognized but deliberately dropped, /// decorations. If the decoration is recognized but deliberately dropped,

View File

@ -3551,7 +3551,9 @@ Maybe<const ast::Attribute*> ParserImpl::attribute() {
} }
match(Token::Type::kComma); match(Token::Type::kComma);
return create<ast::LocationAttribute>(t.source(), val.value); return builder_.Location(t.source(),
create<ast::IntLiteralExpression>(
val.value, ast::IntLiteralExpression::Suffix::kNone));
}); });
} }

View File

@ -256,7 +256,10 @@ TEST_F(ParserImplTest, FunctionDecl_ReturnTypeAttributeList) {
ASSERT_EQ(ret_type_attributes.Length(), 1u); ASSERT_EQ(ret_type_attributes.Length(), 1u);
auto* loc = ret_type_attributes[0]->As<ast::LocationAttribute>(); auto* loc = ret_type_attributes[0]->As<ast::LocationAttribute>();
ASSERT_TRUE(loc != nullptr); ASSERT_TRUE(loc != nullptr);
EXPECT_EQ(loc->value, 1u); EXPECT_TRUE(loc->value->Is<ast::IntLiteralExpression>());
auto* exp = loc->value->As<ast::IntLiteralExpression>();
EXPECT_EQ(1u, exp->value);
auto* body = f->body; auto* body = f->body;
ASSERT_EQ(body->statements.Length(), 1u); ASSERT_EQ(body->statements.Length(), 1u);

View File

@ -54,9 +54,12 @@ TEST_F(ParserImplTest, FunctionHeader_AttributeReturnType) {
EXPECT_EQ(f->params.Length(), 0u); EXPECT_EQ(f->params.Length(), 0u);
EXPECT_TRUE(f->return_type->Is<ast::F32>()); EXPECT_TRUE(f->return_type->Is<ast::F32>());
ASSERT_EQ(f->return_type_attributes.Length(), 1u); ASSERT_EQ(f->return_type_attributes.Length(), 1u);
auto* loc = f->return_type_attributes[0]->As<ast::LocationAttribute>(); auto* loc = f->return_type_attributes[0]->As<ast::LocationAttribute>();
ASSERT_TRUE(loc != nullptr); ASSERT_TRUE(loc != nullptr);
EXPECT_EQ(loc->value, 1u); ASSERT_TRUE(loc->value->Is<ast::IntLiteralExpression>());
auto* exp = loc->value->As<ast::IntLiteralExpression>();
EXPECT_EQ(exp->value, 1u);
} }
TEST_F(ParserImplTest, FunctionHeader_InvariantReturnType) { TEST_F(ParserImplTest, FunctionHeader_InvariantReturnType) {

View File

@ -117,8 +117,12 @@ TEST_F(ParserImplTest, ParamList_Attributes) {
EXPECT_TRUE(e.value[1]->Is<ast::Parameter>()); EXPECT_TRUE(e.value[1]->Is<ast::Parameter>());
auto attrs_1 = e.value[1]->attributes; auto attrs_1 = e.value[1]->attributes;
ASSERT_EQ(attrs_1.Length(), 1u); ASSERT_EQ(attrs_1.Length(), 1u);
EXPECT_TRUE(attrs_1[0]->Is<ast::LocationAttribute>());
EXPECT_EQ(attrs_1[0]->As<ast::LocationAttribute>()->value, 1u); ASSERT_TRUE(attrs_1[0]->Is<ast::LocationAttribute>());
auto* attr = attrs_1[0]->As<ast::LocationAttribute>();
ASSERT_TRUE(attr->value->Is<ast::IntLiteralExpression>());
auto* loc = attr->value->As<ast::IntLiteralExpression>();
EXPECT_EQ(loc->value, 1u);
EXPECT_EQ(e.value[1]->source.range.begin.line, 1u); EXPECT_EQ(e.value[1]->source.range.begin.line, 1u);
EXPECT_EQ(e.value[1]->source.range.begin.column, 52u); EXPECT_EQ(e.value[1]->source.range.begin.column, 52u);

View File

@ -31,7 +31,13 @@ TEST_F(ParserImplTest, AttributeList_Parses) {
ASSERT_NE(attr_1, nullptr); ASSERT_NE(attr_1, nullptr);
ASSERT_TRUE(attr_0->Is<ast::LocationAttribute>()); ASSERT_TRUE(attr_0->Is<ast::LocationAttribute>());
EXPECT_EQ(attr_0->As<ast::LocationAttribute>()->value, 4u);
auto* loc = attr_0->As<ast::LocationAttribute>();
ASSERT_TRUE(loc->value->Is<ast::IntLiteralExpression>());
auto* exp = loc->value->As<ast::IntLiteralExpression>();
EXPECT_EQ(exp->value, 4u);
ASSERT_TRUE(attr_1->Is<ast::BuiltinAttribute>()); ASSERT_TRUE(attr_1->Is<ast::BuiltinAttribute>());
EXPECT_EQ(attr_1->As<ast::BuiltinAttribute>()->builtin, ast::BuiltinValue::kPosition); EXPECT_EQ(attr_1->As<ast::BuiltinAttribute>()->builtin, ast::BuiltinValue::kPosition);
} }

View File

@ -29,7 +29,9 @@ TEST_F(ParserImplTest, Attribute_Location) {
ASSERT_TRUE(var_attr->Is<ast::LocationAttribute>()); ASSERT_TRUE(var_attr->Is<ast::LocationAttribute>());
auto* loc = var_attr->As<ast::LocationAttribute>(); auto* loc = var_attr->As<ast::LocationAttribute>();
EXPECT_EQ(loc->value, 4u); ASSERT_TRUE(loc->value->Is<ast::IntLiteralExpression>());
auto* exp = loc->value->As<ast::IntLiteralExpression>();
EXPECT_EQ(exp->value, 4u);
} }
TEST_F(ParserImplTest, Attribute_Location_TrailingComma) { TEST_F(ParserImplTest, Attribute_Location_TrailingComma) {
@ -44,7 +46,9 @@ TEST_F(ParserImplTest, Attribute_Location_TrailingComma) {
ASSERT_TRUE(var_attr->Is<ast::LocationAttribute>()); ASSERT_TRUE(var_attr->Is<ast::LocationAttribute>());
auto* loc = var_attr->As<ast::LocationAttribute>(); auto* loc = var_attr->As<ast::LocationAttribute>();
EXPECT_EQ(loc->value, 4u); ASSERT_TRUE(loc->value->Is<ast::IntLiteralExpression>());
auto* exp = loc->value->As<ast::IntLiteralExpression>();
EXPECT_EQ(exp->value, 4u);
} }
TEST_F(ParserImplTest, Attribute_Location_MissingLeftParen) { TEST_F(ParserImplTest, Attribute_Location_MissingLeftParen) {

View File

@ -104,7 +104,7 @@ static utils::Vector<const ast::Attribute*, 2> createAttributes(const Source& so
case AttributeKind::kInvariant: case AttributeKind::kInvariant:
return {builder.Invariant(source)}; return {builder.Invariant(source)};
case AttributeKind::kLocation: case AttributeKind::kLocation:
return {builder.Location(source, 1)}; return {builder.Location(source, 1_a)};
case AttributeKind::kOffset: case AttributeKind::kOffset:
return {builder.create<ast::StructMemberOffsetAttribute>(source, 4u)}; return {builder.create<ast::StructMemberOffsetAttribute>(source, 4u)};
case AttributeKind::kSize: case AttributeKind::kSize:
@ -286,7 +286,7 @@ TEST_P(VertexShaderParameterAttributeTest, IsValid) {
auto& params = GetParam(); auto& params = GetParam();
auto attrs = createAttributes(Source{{12, 34}}, *this, params.kind); auto attrs = createAttributes(Source{{12, 34}}, *this, params.kind);
if (params.kind != AttributeKind::kLocation) { if (params.kind != AttributeKind::kLocation) {
attrs.Push(Location(Source{{34, 56}}, 2)); attrs.Push(Location(Source{{34, 56}}, 2_a));
} }
auto* p = Param("a", ty.vec4<f32>(), attrs); auto* p = Param("a", ty.vec4<f32>(), attrs);
Func("vertex_main", utils::Vector{p}, ty.vec4<f32>(), Func("vertex_main", utils::Vector{p}, ty.vec4<f32>(),
@ -388,7 +388,7 @@ using FragmentShaderReturnTypeAttributeTest = TestWithParams;
TEST_P(FragmentShaderReturnTypeAttributeTest, IsValid) { TEST_P(FragmentShaderReturnTypeAttributeTest, IsValid) {
auto& params = GetParam(); auto& params = GetParam();
auto attrs = createAttributes(Source{{12, 34}}, *this, params.kind); auto attrs = createAttributes(Source{{12, 34}}, *this, params.kind);
attrs.Push(Location(Source{{34, 56}}, 2)); attrs.Push(Location(Source{{34, 56}}, 2_a));
Func("frag_main", utils::Empty, ty.vec4<f32>(), Func("frag_main", utils::Empty, ty.vec4<f32>(),
utils::Vector{Return(Construct(ty.vec4<f32>()))}, utils::Vector{Return(Construct(ty.vec4<f32>()))},
utils::Vector{ utils::Vector{
@ -495,8 +495,8 @@ TEST_F(EntryPointParameterAttributeTest, DuplicateAttribute) {
Stage(ast::PipelineStage::kFragment), Stage(ast::PipelineStage::kFragment),
}, },
utils::Vector{ utils::Vector{
Location(Source{{12, 34}}, 2), Location(Source{{12, 34}}, 2_a),
Location(Source{{56, 78}}, 3), Location(Source{{56, 78}}, 3_a),
}); });
EXPECT_FALSE(r()->Resolve()); EXPECT_FALSE(r()->Resolve());
@ -531,8 +531,8 @@ TEST_F(EntryPointReturnTypeAttributeTest, DuplicateAttribute) {
Stage(ast::PipelineStage::kFragment), Stage(ast::PipelineStage::kFragment),
}, },
utils::Vector{ utils::Vector{
Location(Source{{12, 34}}, 2), Location(Source{{12, 34}}, 2_a),
Location(Source{{56, 78}}, 3), Location(Source{{56, 78}}, 3_a),
}); });
EXPECT_FALSE(r()->Resolve()); EXPECT_FALSE(r()->Resolve());
@ -1101,7 +1101,7 @@ TEST_F(InvariantAttributeTests, InvariantWithPosition) {
Stage(ast::PipelineStage::kFragment), Stage(ast::PipelineStage::kFragment),
}, },
utils::Vector{ utils::Vector{
Location(0), Location(0_a),
}); });
EXPECT_TRUE(r()->Resolve()) << r()->error(); EXPECT_TRUE(r()->Resolve()) << r()->error();
} }
@ -1110,7 +1110,7 @@ TEST_F(InvariantAttributeTests, InvariantWithoutPosition) {
auto* param = Param("p", ty.vec4<f32>(), auto* param = Param("p", ty.vec4<f32>(),
utils::Vector{ utils::Vector{
Invariant(Source{{12, 34}}), Invariant(Source{{12, 34}}),
Location(0), Location(0_a),
}); });
Func("main", utils::Vector{param}, ty.vec4<f32>(), Func("main", utils::Vector{param}, ty.vec4<f32>(),
utils::Vector{ utils::Vector{
@ -1120,7 +1120,7 @@ TEST_F(InvariantAttributeTests, InvariantWithoutPosition) {
Stage(ast::PipelineStage::kFragment), Stage(ast::PipelineStage::kFragment),
}, },
utils::Vector{ utils::Vector{
Location(0), Location(0_a),
}); });
EXPECT_FALSE(r()->Resolve()); EXPECT_FALSE(r()->Resolve());
EXPECT_EQ(r()->error(), EXPECT_EQ(r()->error(),
@ -1219,7 +1219,7 @@ TEST_P(InterpolateParameterTest, All) {
utils::Vector{ utils::Vector{
Param("a", ty.f32(), Param("a", ty.f32(),
utils::Vector{ utils::Vector{
Location(0), Location(0_a),
Interpolate(Source{{12, 34}}, params.type, params.sampling), Interpolate(Source{{12, 34}}, params.type, params.sampling),
}), }),
}, },
@ -1245,7 +1245,7 @@ TEST_P(InterpolateParameterTest, IntegerScalar) {
utils::Vector{ utils::Vector{
Param("a", ty.i32(), Param("a", ty.i32(),
utils::Vector{ utils::Vector{
Location(0), Location(0_a),
Interpolate(Source{{12, 34}}, params.type, params.sampling), Interpolate(Source{{12, 34}}, params.type, params.sampling),
}), }),
}, },
@ -1276,7 +1276,7 @@ TEST_P(InterpolateParameterTest, IntegerVector) {
utils::Vector{ utils::Vector{
Param("a", ty.vec4<u32>(), Param("a", ty.vec4<u32>(),
utils::Vector{ utils::Vector{
Location(0), Location(0_a),
Interpolate(Source{{12, 34}}, params.type, params.sampling), Interpolate(Source{{12, 34}}, params.type, params.sampling),
}), }),
}, },
@ -1319,7 +1319,8 @@ INSTANTIATE_TEST_SUITE_P(
Params{ast::InterpolationType::kFlat, ast::InterpolationSampling::kSample, false})); Params{ast::InterpolationType::kFlat, ast::InterpolationSampling::kSample, false}));
TEST_F(InterpolateTest, FragmentInput_Integer_MissingFlatInterpolation) { TEST_F(InterpolateTest, FragmentInput_Integer_MissingFlatInterpolation) {
Func("main", utils::Vector{Param(Source{{12, 34}}, "a", ty.i32(), utils::Vector{Location(0)})}, Func("main",
utils::Vector{Param(Source{{12, 34}}, "a", ty.i32(), utils::Vector{Location(0_a)})},
ty.void_(), utils::Empty, ty.void_(), utils::Empty,
utils::Vector{ utils::Vector{
Stage(ast::PipelineStage::kFragment), Stage(ast::PipelineStage::kFragment),
@ -1336,7 +1337,7 @@ TEST_F(InterpolateTest, VertexOutput_Integer_MissingFlatInterpolation) {
"S", "S",
utils::Vector{ utils::Vector{
Member("pos", ty.vec4<f32>(), utils::Vector{Builtin(ast::BuiltinValue::kPosition)}), Member("pos", ty.vec4<f32>(), utils::Vector{Builtin(ast::BuiltinValue::kPosition)}),
Member(Source{{12, 34}}, "u", ty.u32(), utils::Vector{Location(0)}), Member(Source{{12, 34}}, "u", ty.u32(), utils::Vector{Location(0_a)}),
}); });
Func("main", utils::Empty, ty.Of(s), Func("main", utils::Empty, ty.Of(s),
utils::Vector{ utils::Vector{

View File

@ -163,7 +163,7 @@ TEST_F(ResolverBuiltinsValidationTest, FragDepthIsInput_Fail) {
Stage(ast::PipelineStage::kFragment), Stage(ast::PipelineStage::kFragment),
}, },
utils::Vector{ utils::Vector{
Location(0), Location(0_a),
}); });
EXPECT_FALSE(r()->Resolve()); EXPECT_FALSE(r()->Resolve());
EXPECT_EQ(r()->error(), EXPECT_EQ(r()->error(),
@ -198,7 +198,7 @@ TEST_F(ResolverBuiltinsValidationTest, FragDepthIsInputStruct_Fail) {
Stage(ast::PipelineStage::kFragment), Stage(ast::PipelineStage::kFragment),
}, },
utils::Vector{ utils::Vector{
Location(0), Location(0_a),
}); });
EXPECT_FALSE(r()->Resolve()); EXPECT_FALSE(r()->Resolve());
EXPECT_EQ(r()->error(), EXPECT_EQ(r()->error(),
@ -256,7 +256,7 @@ TEST_F(ResolverBuiltinsValidationTest, PositionNotF32_Struct_Fail) {
Stage(ast::PipelineStage::kFragment), Stage(ast::PipelineStage::kFragment),
}, },
utils::Vector{ utils::Vector{
Location(0), Location(0_a),
}); });
EXPECT_FALSE(r()->Resolve()); EXPECT_FALSE(r()->Resolve());
@ -301,7 +301,7 @@ TEST_F(ResolverBuiltinsValidationTest, FragDepthNotF32_Struct_Fail) {
Stage(ast::PipelineStage::kFragment), Stage(ast::PipelineStage::kFragment),
}, },
utils::Vector{ utils::Vector{
Location(0), Location(0_a),
}); });
EXPECT_FALSE(r()->Resolve()); EXPECT_FALSE(r()->Resolve());
@ -330,7 +330,7 @@ TEST_F(ResolverBuiltinsValidationTest, SampleMaskNotU32_Struct_Fail) {
Stage(ast::PipelineStage::kFragment), Stage(ast::PipelineStage::kFragment),
}, },
utils::Vector{ utils::Vector{
Location(0), Location(0_a),
}); });
EXPECT_FALSE(r()->Resolve()); EXPECT_FALSE(r()->Resolve());
@ -372,7 +372,7 @@ TEST_F(ResolverBuiltinsValidationTest, SampleMaskIsNotU32_Fail) {
Stage(ast::PipelineStage::kFragment), Stage(ast::PipelineStage::kFragment),
}, },
utils::Vector{ utils::Vector{
Location(0), Location(0_a),
}); });
EXPECT_FALSE(r()->Resolve()); EXPECT_FALSE(r()->Resolve());
EXPECT_EQ(r()->error(), "12:34 error: store type of builtin(sample_mask) must be 'u32'"); EXPECT_EQ(r()->error(), "12:34 error: store type of builtin(sample_mask) must be 'u32'");
@ -400,7 +400,7 @@ TEST_F(ResolverBuiltinsValidationTest, SampleIndexIsNotU32_Struct_Fail) {
Stage(ast::PipelineStage::kFragment), Stage(ast::PipelineStage::kFragment),
}, },
utils::Vector{ utils::Vector{
Location(0), Location(0_a),
}); });
EXPECT_FALSE(r()->Resolve()); EXPECT_FALSE(r()->Resolve());
@ -427,7 +427,7 @@ TEST_F(ResolverBuiltinsValidationTest, SampleIndexIsNotU32_Fail) {
Stage(ast::PipelineStage::kFragment), Stage(ast::PipelineStage::kFragment),
}, },
utils::Vector{ utils::Vector{
Location(0), Location(0_a),
}); });
EXPECT_FALSE(r()->Resolve()); EXPECT_FALSE(r()->Resolve());
EXPECT_EQ(r()->error(), "12:34 error: store type of builtin(sample_index) must be 'u32'"); EXPECT_EQ(r()->error(), "12:34 error: store type of builtin(sample_index) must be 'u32'");
@ -453,7 +453,7 @@ TEST_F(ResolverBuiltinsValidationTest, PositionIsNotF32_Fail) {
Stage(ast::PipelineStage::kFragment), Stage(ast::PipelineStage::kFragment),
}, },
utils::Vector{ utils::Vector{
Location(0), Location(0_a),
}); });
EXPECT_FALSE(r()->Resolve()); EXPECT_FALSE(r()->Resolve());
EXPECT_EQ(r()->error(), "12:34 error: store type of builtin(position) must be 'vec4<f32>'"); EXPECT_EQ(r()->error(), "12:34 error: store type of builtin(position) must be 'vec4<f32>'");
@ -745,7 +745,7 @@ TEST_F(ResolverBuiltinsValidationTest, FragmentBuiltinStruct_Pass) {
Stage(ast::PipelineStage::kFragment), Stage(ast::PipelineStage::kFragment),
}, },
utils::Vector{ utils::Vector{
Location(0), Location(0_a),
}); });
EXPECT_TRUE(r()->Resolve()) << r()->error(); EXPECT_TRUE(r()->Resolve()) << r()->error();
} }
@ -768,7 +768,7 @@ TEST_F(ResolverBuiltinsValidationTest, FrontFacingParamIsNotBool_Fail) {
Stage(ast::PipelineStage::kFragment), Stage(ast::PipelineStage::kFragment),
}, },
utils::Vector{ utils::Vector{
Location(0), Location(0_a),
}); });
EXPECT_FALSE(r()->Resolve()); EXPECT_FALSE(r()->Resolve());
@ -797,7 +797,7 @@ TEST_F(ResolverBuiltinsValidationTest, FrontFacingMemberIsNotBool_Fail) {
Stage(ast::PipelineStage::kFragment), Stage(ast::PipelineStage::kFragment),
}, },
utils::Vector{ utils::Vector{
Location(0), Location(0_a),
}); });
EXPECT_FALSE(r()->Resolve()); EXPECT_FALSE(r()->Resolve());

View File

@ -57,7 +57,7 @@ TEST_F(ResolverEntryPointValidationTest, ReturnTypeAttribute_Location) {
Stage(ast::PipelineStage::kFragment), Stage(ast::PipelineStage::kFragment),
}, },
utils::Vector{ utils::Vector{
Location(0), Location(0_a),
}); });
EXPECT_TRUE(r()->Resolve()) << r()->error(); EXPECT_TRUE(r()->Resolve()) << r()->error();
@ -110,7 +110,7 @@ TEST_F(ResolverEntryPointValidationTest, ReturnTypeAttribute_Multiple) {
Stage(ast::PipelineStage::kVertex), Stage(ast::PipelineStage::kVertex),
}, },
utils::Vector{ utils::Vector{
Location(Source{{13, 43}}, 0), Location(Source{{13, 43}}, 0_a),
Builtin(Source{{14, 52}}, ast::BuiltinValue::kPosition), Builtin(Source{{14, 52}}, ast::BuiltinValue::kPosition),
}); });
@ -130,7 +130,7 @@ TEST_F(ResolverEntryPointValidationTest, ReturnType_Struct_Valid) {
// } // }
auto* output = Structure( auto* output = Structure(
"Output", utils::Vector{ "Output", utils::Vector{
Member("a", ty.f32(), utils::Vector{Location(0)}), Member("a", ty.f32(), utils::Vector{Location(0_a)}),
Member("b", ty.f32(), utils::Vector{Builtin(ast::BuiltinValue::kFragDepth)}), Member("b", ty.f32(), utils::Vector{Builtin(ast::BuiltinValue::kFragDepth)}),
}); });
Func(Source{{12, 34}}, "main", utils::Empty, ty.Of(output), Func(Source{{12, 34}}, "main", utils::Empty, ty.Of(output),
@ -156,7 +156,7 @@ TEST_F(ResolverEntryPointValidationTest, ReturnType_Struct_MemberMultipleAttribu
"Output", "Output",
utils::Vector{ utils::Vector{
Member("a", ty.f32(), Member("a", ty.f32(),
utils::Vector{Location(Source{{13, 43}}, 0), utils::Vector{Location(Source{{13, 43}}, 0_a),
Builtin(Source{{14, 52}}, ast::BuiltinValue::kFragDepth)}), Builtin(Source{{14, 52}}, ast::BuiltinValue::kFragDepth)}),
}); });
Func(Source{{12, 34}}, "main", utils::Empty, ty.Of(output), Func(Source{{12, 34}}, "main", utils::Empty, ty.Of(output),
@ -182,9 +182,9 @@ TEST_F(ResolverEntryPointValidationTest, ReturnType_Struct_MemberMissingAttribut
// fn main() -> Output { // fn main() -> Output {
// return Output(); // return Output();
// } // }
auto* output = auto* output = Structure(
Structure("Output", utils::Vector{ "Output", utils::Vector{
Member(Source{{13, 43}}, "a", ty.f32(), utils::Vector{Location(0)}), Member(Source{{13, 43}}, "a", ty.f32(), utils::Vector{Location(0_a)}),
Member(Source{{14, 52}}, "b", ty.f32(), {}), Member(Source{{14, 52}}, "b", ty.f32(), {}),
}); });
Func(Source{{12, 34}}, "main", utils::Empty, ty.Of(output), Func(Source{{12, 34}}, "main", utils::Empty, ty.Of(output),
@ -235,7 +235,7 @@ TEST_F(ResolverEntryPointValidationTest, ParameterAttribute_Location) {
// fn main(@location(0) param : f32) {} // fn main(@location(0) param : f32) {}
auto* param = Param("param", ty.f32(), auto* param = Param("param", ty.f32(),
utils::Vector{ utils::Vector{
Location(0), Location(0_a),
}); });
Func(Source{{12, 34}}, "main", Func(Source{{12, 34}}, "main",
utils::Vector{ utils::Vector{
@ -271,7 +271,7 @@ TEST_F(ResolverEntryPointValidationTest, ParameterAttribute_Multiple) {
// fn main(@location(0) @builtin(sample_index) param : u32) {} // fn main(@location(0) @builtin(sample_index) param : u32) {}
auto* param = Param("param", ty.u32(), auto* param = Param("param", ty.u32(),
utils::Vector{ utils::Vector{
Location(Source{{13, 43}}, 0), Location(Source{{13, 43}}, 0_a),
Builtin(Source{{14, 52}}, ast::BuiltinValue::kSampleIndex), Builtin(Source{{14, 52}}, ast::BuiltinValue::kSampleIndex),
}); });
Func(Source{{12, 34}}, "main", Func(Source{{12, 34}}, "main",
@ -297,7 +297,7 @@ TEST_F(ResolverEntryPointValidationTest, Parameter_Struct_Valid) {
// fn main(param : Input) {} // fn main(param : Input) {}
auto* input = Structure( auto* input = Structure(
"Input", utils::Vector{ "Input", utils::Vector{
Member("a", ty.f32(), utils::Vector{Location(0)}), Member("a", ty.f32(), utils::Vector{Location(0_a)}),
Member("b", ty.u32(), utils::Vector{Builtin(ast::BuiltinValue::kSampleIndex)}), Member("b", ty.u32(), utils::Vector{Builtin(ast::BuiltinValue::kSampleIndex)}),
}); });
auto* param = Param("param", ty.Of(input)); auto* param = Param("param", ty.Of(input));
@ -323,7 +323,7 @@ TEST_F(ResolverEntryPointValidationTest, Parameter_Struct_MemberMultipleAttribut
"Input", "Input",
utils::Vector{ utils::Vector{
Member("a", ty.u32(), Member("a", ty.u32(),
utils::Vector{Location(Source{{13, 43}}, 0), utils::Vector{Location(Source{{13, 43}}, 0_a),
Builtin(Source{{14, 52}}, ast::BuiltinValue::kSampleIndex)}), Builtin(Source{{14, 52}}, ast::BuiltinValue::kSampleIndex)}),
}); });
auto* param = Param("param", ty.Of(input)); auto* param = Param("param", ty.Of(input));
@ -349,9 +349,9 @@ TEST_F(ResolverEntryPointValidationTest, Parameter_Struct_MemberMissingAttribute
// }; // };
// @fragment // @fragment
// fn main(param : Input) {} // fn main(param : Input) {}
auto* input = auto* input = Structure(
Structure("Input", utils::Vector{ "Input", utils::Vector{
Member(Source{{13, 43}}, "a", ty.f32(), utils::Vector{Location(0)}), Member(Source{{13, 43}}, "a", ty.f32(), utils::Vector{Location(0_a)}),
Member(Source{{14, 52}}, "b", ty.f32(), {}), Member(Source{{14, 52}}, "b", ty.f32(), {}),
}); });
auto* param = Param("param", ty.Of(input)); auto* param = Param("param", ty.Of(input));
@ -628,7 +628,7 @@ TEST_P(TypeValidationTest, BareInputs) {
auto* a = Param("a", params.create_ast_type(*this), auto* a = Param("a", params.create_ast_type(*this),
utils::Vector{ utils::Vector{
Location(0), Location(0_a),
Flat(), Flat(),
}); });
Func(Source{{12, 34}}, "main", Func(Source{{12, 34}}, "main",
@ -657,9 +657,9 @@ TEST_P(TypeValidationTest, StructInputs) {
Enable(ast::Extension::kF16); Enable(ast::Extension::kF16);
auto* input = Structure( auto* input = Structure("Input", utils::Vector{
"Input", utils::Vector{ Member("a", params.create_ast_type(*this),
Member("a", params.create_ast_type(*this), utils::Vector{Location(0), Flat()}), utils::Vector{Location(0_a), Flat()}),
}); });
auto* a = Param("a", ty.Of(input), {}); auto* a = Param("a", ty.Of(input), {});
Func(Source{{12, 34}}, "main", Func(Source{{12, 34}}, "main",
@ -695,7 +695,7 @@ TEST_P(TypeValidationTest, BareOutputs) {
Stage(ast::PipelineStage::kFragment), Stage(ast::PipelineStage::kFragment),
}, },
utils::Vector{ utils::Vector{
Location(0), Location(0_a),
}); });
if (params.is_valid) { if (params.is_valid) {
@ -719,7 +719,7 @@ TEST_P(TypeValidationTest, StructOutputs) {
auto* output = Structure( auto* output = Structure(
"Output", utils::Vector{ "Output", utils::Vector{
Member("a", params.create_ast_type(*this), utils::Vector{Location(0)}), Member("a", params.create_ast_type(*this), utils::Vector{Location(0_a)}),
}); });
Func(Source{{12, 34}}, "main", utils::Empty, ty.Of(output), Func(Source{{12, 34}}, "main", utils::Empty, ty.Of(output),
utils::Vector{ utils::Vector{
@ -751,7 +751,7 @@ TEST_F(LocationAttributeTests, Pass) {
auto* p = Param(Source{{12, 34}}, "a", ty.i32(), auto* p = Param(Source{{12, 34}}, "a", ty.i32(),
utils::Vector{ utils::Vector{
Location(0), Location(0_a),
Flat(), Flat(),
}); });
Func("frag_main", Func("frag_main",
@ -772,7 +772,7 @@ TEST_F(LocationAttributeTests, BadType_Input_bool) {
auto* p = Param(Source{{12, 34}}, "a", ty.bool_(), auto* p = Param(Source{{12, 34}}, "a", ty.bool_(),
utils::Vector{ utils::Vector{
Location(Source{{34, 56}}, 0), Location(Source{{34, 56}}, 0_a),
}); });
Func("frag_main", Func("frag_main",
utils::Vector{ utils::Vector{
@ -803,7 +803,7 @@ TEST_F(LocationAttributeTests, BadType_Output_Array) {
Stage(ast::PipelineStage::kFragment), Stage(ast::PipelineStage::kFragment),
}, },
utils::Vector{ utils::Vector{
Location(Source{{34, 56}}, 0), Location(Source{{34, 56}}, 0_a),
}); });
EXPECT_FALSE(r()->Resolve()); EXPECT_FALSE(r()->Resolve());
@ -825,7 +825,7 @@ TEST_F(LocationAttributeTests, BadType_Input_Struct) {
}); });
auto* param = Param(Source{{12, 34}}, "param", ty.Of(input), auto* param = Param(Source{{12, 34}}, "param", ty.Of(input),
utils::Vector{ utils::Vector{
Location(Source{{13, 43}}, 0), Location(Source{{13, 43}}, 0_a),
}); });
Func(Source{{12, 34}}, "main", Func(Source{{12, 34}}, "main",
utils::Vector{ utils::Vector{
@ -853,9 +853,9 @@ TEST_F(LocationAttributeTests, BadType_Input_Struct_NestedStruct) {
// }; // };
// @fragment // @fragment
// fn main(param : Input) {} // fn main(param : Input) {}
auto* inner = auto* inner = Structure(
Structure("Inner", utils::Vector{ "Inner", utils::Vector{
Member(Source{{13, 43}}, "a", ty.f32(), utils::Vector{Location(0)}), Member(Source{{13, 43}}, "a", ty.f32(), utils::Vector{Location(0_a)}),
}); });
auto* input = Structure("Input", utils::Vector{ auto* input = Structure("Input", utils::Vector{
Member(Source{{14, 52}}, "a", ty.Of(inner)), Member(Source{{14, 52}}, "a", ty.Of(inner)),
@ -884,7 +884,7 @@ TEST_F(LocationAttributeTests, BadType_Input_Struct_RuntimeArray) {
// fn main(param : Input) {} // fn main(param : Input) {}
auto* input = Structure( auto* input = Structure(
"Input", utils::Vector{ "Input", utils::Vector{
Member(Source{{13, 43}}, "a", ty.array<f32>(), utils::Vector{Location(0)}), Member(Source{{13, 43}}, "a", ty.array<f32>(), utils::Vector{Location(0_a)}),
}); });
auto* param = Param("param", ty.Of(input)); auto* param = Param("param", ty.Of(input));
Func(Source{{12, 34}}, "main", Func(Source{{12, 34}}, "main",
@ -911,7 +911,7 @@ TEST_F(LocationAttributeTests, BadMemberType_Input) {
auto* m = Member(Source{{34, 56}}, "m", ty.array<i32>(), auto* m = Member(Source{{34, 56}}, "m", ty.array<i32>(),
utils::Vector{ utils::Vector{
Location(Source{{12, 34}}, 0u), Location(Source{{12, 34}}, 0_u),
}); });
auto* s = Structure("S", utils::Vector{m}); auto* s = Structure("S", utils::Vector{m});
auto* p = Param("a", ty.Of(s)); auto* p = Param("a", ty.Of(s));
@ -939,7 +939,7 @@ TEST_F(LocationAttributeTests, BadMemberType_Output) {
// fn frag_main() -> S {} // fn frag_main() -> S {}
auto* m = Member(Source{{34, 56}}, "m", ty.atomic<i32>(), auto* m = Member(Source{{34, 56}}, "m", ty.atomic<i32>(),
utils::Vector{ utils::Vector{
Location(Source{{12, 34}}, 0u), Location(Source{{12, 34}}, 0_u),
}); });
auto* s = Structure("S", utils::Vector{m}); auto* s = Structure("S", utils::Vector{m});
@ -965,7 +965,7 @@ TEST_F(LocationAttributeTests, BadMemberType_Unused) {
auto* m = Member(Source{{34, 56}}, "m", ty.mat3x2<f32>(), auto* m = Member(Source{{34, 56}}, "m", ty.mat3x2<f32>(),
utils::Vector{ utils::Vector{
Location(Source{{12, 34}}, 0u), Location(Source{{12, 34}}, 0_u),
}); });
Structure("S", utils::Vector{m}); Structure("S", utils::Vector{m});
@ -988,7 +988,7 @@ TEST_F(LocationAttributeTests, ReturnType_Struct_Valid) {
// } // }
auto* output = Structure( auto* output = Structure(
"Output", utils::Vector{ "Output", utils::Vector{
Member("a", ty.f32(), utils::Vector{Location(0)}), Member("a", ty.f32(), utils::Vector{Location(0_a)}),
Member("b", ty.f32(), utils::Vector{Builtin(ast::BuiltinValue::kFragDepth)}), Member("b", ty.f32(), utils::Vector{Builtin(ast::BuiltinValue::kFragDepth)}),
}); });
Func(Source{{12, 34}}, "main", utils::Empty, ty.Of(output), Func(Source{{12, 34}}, "main", utils::Empty, ty.Of(output),
@ -1021,7 +1021,7 @@ TEST_F(LocationAttributeTests, ReturnType_Struct) {
Stage(ast::PipelineStage::kVertex), Stage(ast::PipelineStage::kVertex),
}, },
utils::Vector{ utils::Vector{
Location(Source{{13, 43}}, 0), Location(Source{{13, 43}}, 0_a),
}); });
EXPECT_FALSE(r()->Resolve()); EXPECT_FALSE(r()->Resolve());
@ -1041,9 +1041,9 @@ TEST_F(LocationAttributeTests, ReturnType_Struct_NestedStruct) {
// }; // };
// @fragment // @fragment
// fn main() -> Output { return Output(); } // fn main() -> Output { return Output(); }
auto* inner = auto* inner = Structure(
Structure("Inner", utils::Vector{ "Inner", utils::Vector{
Member(Source{{13, 43}}, "a", ty.f32(), utils::Vector{Location(0)}), Member(Source{{13, 43}}, "a", ty.f32(), utils::Vector{Location(0_a)}),
}); });
auto* output = Structure("Output", utils::Vector{ auto* output = Structure("Output", utils::Vector{
Member(Source{{14, 52}}, "a", ty.Of(inner)), Member(Source{{14, 52}}, "a", ty.Of(inner)),
@ -1072,7 +1072,7 @@ TEST_F(LocationAttributeTests, ReturnType_Struct_RuntimeArray) {
// } // }
auto* output = Structure("Output", utils::Vector{ auto* output = Structure("Output", utils::Vector{
Member(Source{{13, 43}}, "a", ty.array<f32>(), Member(Source{{13, 43}}, "a", ty.array<f32>(),
utils::Vector{Location(Source{{12, 34}}, 0)}), utils::Vector{Location(Source{{12, 34}}, 0_a)}),
}); });
Func(Source{{12, 34}}, "main", utils::Empty, ty.Of(output), Func(Source{{12, 34}}, "main", utils::Empty, ty.Of(output),
utils::Vector{ utils::Vector{
@ -1100,7 +1100,7 @@ TEST_F(LocationAttributeTests, ComputeShaderLocation_Input) {
create<ast::WorkgroupAttribute>(Source{{12, 34}}, Expr(1_i)), create<ast::WorkgroupAttribute>(Source{{12, 34}}, Expr(1_i)),
}, },
utils::Vector{ utils::Vector{
Location(Source{{12, 34}}, 1), Location(Source{{12, 34}}, 1_a),
}); });
EXPECT_FALSE(r()->Resolve()); EXPECT_FALSE(r()->Resolve());
@ -1110,7 +1110,7 @@ TEST_F(LocationAttributeTests, ComputeShaderLocation_Input) {
TEST_F(LocationAttributeTests, ComputeShaderLocation_Output) { TEST_F(LocationAttributeTests, ComputeShaderLocation_Output) {
auto* input = Param("input", ty.i32(), auto* input = Param("input", ty.i32(),
utils::Vector{ utils::Vector{
Location(Source{{12, 34}}, 0u), Location(Source{{12, 34}}, 0_u),
}); });
Func("main", utils::Vector{input}, ty.void_(), utils::Empty, Func("main", utils::Vector{input}, ty.void_(), utils::Empty,
utils::Vector{ utils::Vector{
@ -1125,7 +1125,7 @@ TEST_F(LocationAttributeTests, ComputeShaderLocation_Output) {
TEST_F(LocationAttributeTests, ComputeShaderLocationStructMember_Output) { TEST_F(LocationAttributeTests, ComputeShaderLocationStructMember_Output) {
auto* m = Member("m", ty.i32(), auto* m = Member("m", ty.i32(),
utils::Vector{ utils::Vector{
Location(Source{{12, 34}}, 0u), Location(Source{{12, 34}}, 0_u),
}); });
auto* s = Structure("S", utils::Vector{m}); auto* s = Structure("S", utils::Vector{m});
Func(Source{{56, 78}}, "main", utils::Empty, ty.Of(s), Func(Source{{56, 78}}, "main", utils::Empty, ty.Of(s),
@ -1146,7 +1146,7 @@ TEST_F(LocationAttributeTests, ComputeShaderLocationStructMember_Output) {
TEST_F(LocationAttributeTests, ComputeShaderLocationStructMember_Input) { TEST_F(LocationAttributeTests, ComputeShaderLocationStructMember_Input) {
auto* m = Member("m", ty.i32(), auto* m = Member("m", ty.i32(),
utils::Vector{ utils::Vector{
Location(Source{{12, 34}}, 0u), Location(Source{{12, 34}}, 0_u),
}); });
auto* s = Structure("S", utils::Vector{m}); auto* s = Structure("S", utils::Vector{m});
auto* input = Param("input", ty.Of(s)); auto* input = Param("input", ty.Of(s));
@ -1168,11 +1168,11 @@ TEST_F(LocationAttributeTests, Duplicate_input) {
// @location(1) param_b : f32) {} // @location(1) param_b : f32) {}
auto* param_a = Param("param_a", ty.f32(), auto* param_a = Param("param_a", ty.f32(),
utils::Vector{ utils::Vector{
Location(1), Location(1_a),
}); });
auto* param_b = Param("param_b", ty.f32(), auto* param_b = Param("param_b", ty.f32(),
utils::Vector{ utils::Vector{
Location(Source{{12, 34}}, 1), Location(Source{{12, 34}}, 1_a),
}); });
Func(Source{{12, 34}}, "main", Func(Source{{12, 34}}, "main",
utils::Vector{ utils::Vector{
@ -1198,11 +1198,11 @@ TEST_F(LocationAttributeTests, Duplicate_struct) {
// @fragment // @fragment
// fn main(param_a : InputA, param_b : InputB) {} // fn main(param_a : InputA, param_b : InputB) {}
auto* input_a = Structure("InputA", utils::Vector{ auto* input_a = Structure("InputA", utils::Vector{
Member("a", ty.f32(), utils::Vector{Location(1)}), Member("a", ty.f32(), utils::Vector{Location(1_a)}),
}); });
auto* input_b = auto* input_b = Structure(
Structure("InputB", utils::Vector{ "InputB", utils::Vector{
Member("a", ty.f32(), utils::Vector{Location(Source{{34, 56}}, 1)}), Member("a", ty.f32(), utils::Vector{Location(Source{{34, 56}}, 1_a)}),
}); });
auto* param_a = Param("param_a", ty.Of(input_a)); auto* param_a = Param("param_a", ty.Of(input_a));
auto* param_b = Param("param_b", ty.Of(input_b)); auto* param_b = Param("param_b", ty.Of(input_b));

View File

@ -640,7 +640,17 @@ sem::Variable* Resolver::Var(const ast::Var* var, bool is_global) {
std::optional<uint32_t> location; std::optional<uint32_t> location;
if (auto* attr = ast::GetAttribute<ast::LocationAttribute>(var->attributes)) { if (auto* attr = ast::GetAttribute<ast::LocationAttribute>(var->attributes)) {
location = attr->value; auto* materialize = Materialize(Expression(attr->value));
if (!materialize) {
return nullptr;
}
auto* c = materialize->ConstantValue();
if (!c) {
// TODO(crbug.com/tint/1633): Add error message about invalid materialization
// when location can be an expression.
return nullptr;
}
location = c->As<uint32_t>();
} }
sem = builder_->create<sem::GlobalVariable>( sem = builder_->create<sem::GlobalVariable>(
@ -725,7 +735,17 @@ sem::Parameter* Resolver::Parameter(const ast::Parameter* param, uint32_t index)
std::optional<uint32_t> location; std::optional<uint32_t> location;
if (auto* l = ast::GetAttribute<ast::LocationAttribute>(param->attributes)) { if (auto* l = ast::GetAttribute<ast::LocationAttribute>(param->attributes)) {
location = l->value; auto* materialize = Materialize(Expression(l->value));
if (!materialize) {
return nullptr;
}
auto* c = materialize->ConstantValue();
if (!c) {
// TODO(crbug.com/tint/1633): Add error message about invalid materialization when
// location can be an expression.
return nullptr;
}
location = c->As<uint32_t>();
} }
auto* sem = builder_->create<sem::Parameter>( auto* sem = builder_->create<sem::Parameter>(
@ -924,7 +944,17 @@ sem::Function* Resolver::Function(const ast::Function* decl) {
Mark(attr); Mark(attr);
if (auto* a = attr->As<ast::LocationAttribute>()) { if (auto* a = attr->As<ast::LocationAttribute>()) {
return_location = a->value; auto* materialize = Materialize(Expression(a->value));
if (!materialize) {
return nullptr;
}
auto* c = materialize->ConstantValue();
if (!c) {
// TODO(crbug.com/tint/1633): Add error message about invalid materialization when
// location can be an expression.
return nullptr;
}
return_location = c->As<uint32_t>();
} }
} }
if (!validator_.NoDuplicateAttributes(decl->attributes)) { if (!validator_.NoDuplicateAttributes(decl->attributes)) {
@ -2808,7 +2838,17 @@ sem::Struct* Resolver::Structure(const ast::Struct* str) {
size = s->size; size = s->size;
has_size_attr = true; has_size_attr = true;
} else if (auto* l = attr->As<ast::LocationAttribute>()) { } else if (auto* l = attr->As<ast::LocationAttribute>()) {
location = l->value; auto* materialize = Materialize(Expression(l->value));
if (!materialize) {
return nullptr;
}
auto* c = materialize->ConstantValue();
if (!c) {
// TODO(crbug.com/tint/1633): Add error message about invalid materialization
// when location can be an expression.
return nullptr;
}
location = c->As<uint32_t>();
} }
} }

View File

@ -774,9 +774,9 @@ TEST_F(ResolverTest, Function_Parameters) {
} }
TEST_F(ResolverTest, Function_Parameters_Locations) { TEST_F(ResolverTest, Function_Parameters_Locations) {
auto* param_a = Param("a", ty.f32(), utils::Vector{Location(3)}); auto* param_a = Param("a", ty.f32(), utils::Vector{Location(3_a)});
auto* param_b = Param("b", ty.u32(), utils::Vector{Builtin(ast::BuiltinValue::kVertexIndex)}); auto* param_b = Param("b", ty.u32(), utils::Vector{Builtin(ast::BuiltinValue::kVertexIndex)});
auto* param_c = Param("c", ty.u32(), utils::Vector{Location(1)}); auto* param_c = Param("c", ty.u32(), utils::Vector{Location(1_a)});
GlobalVar("my_vec", ty.vec4<f32>(), ast::StorageClass::kPrivate); GlobalVar("my_vec", ty.vec4<f32>(), ast::StorageClass::kPrivate);
auto* func = Func("my_func", auto* func = Func("my_func",
@ -809,7 +809,7 @@ TEST_F(ResolverTest, Function_Parameters_Locations) {
TEST_F(ResolverTest, Function_GlobalVariable_Location) { TEST_F(ResolverTest, Function_GlobalVariable_Location) {
auto* var = GlobalVar( auto* var = GlobalVar(
"my_vec", ty.vec4<f32>(), ast::StorageClass::kIn, "my_vec", ty.vec4<f32>(), ast::StorageClass::kIn,
utils::Vector{Location(3), Disable(ast::DisabledValidation::kIgnoreStorageClass)}); utils::Vector{Location(3_a), Disable(ast::DisabledValidation::kIgnoreStorageClass)});
EXPECT_TRUE(r()->Resolve()) << r()->error(); EXPECT_TRUE(r()->Resolve()) << r()->error();
@ -856,7 +856,7 @@ TEST_F(ResolverTest, Function_ReturnType_Location) {
Stage(ast::PipelineStage::kFragment), Stage(ast::PipelineStage::kFragment),
}, },
utils::Vector{ utils::Vector{
Location(2), Location(2_a),
}); });
EXPECT_TRUE(r()->Resolve()) << r()->error(); EXPECT_TRUE(r()->Resolve()) << r()->error();

View File

@ -29,7 +29,7 @@ namespace {
using ResolverPipelineStageUseTest = ResolverTest; using ResolverPipelineStageUseTest = ResolverTest;
TEST_F(ResolverPipelineStageUseTest, UnusedStruct) { TEST_F(ResolverPipelineStageUseTest, UnusedStruct) {
auto* s = Structure("S", utils::Vector{Member("a", ty.f32(), utils::Vector{Location(0)})}); auto* s = Structure("S", utils::Vector{Member("a", ty.f32(), utils::Vector{Location(0_a)})});
ASSERT_TRUE(r()->Resolve()) << r()->error(); ASSERT_TRUE(r()->Resolve()) << r()->error();
@ -39,7 +39,7 @@ TEST_F(ResolverPipelineStageUseTest, UnusedStruct) {
} }
TEST_F(ResolverPipelineStageUseTest, StructUsedAsNonEntryPointParam) { TEST_F(ResolverPipelineStageUseTest, StructUsedAsNonEntryPointParam) {
auto* s = Structure("S", utils::Vector{Member("a", ty.f32(), utils::Vector{Location(0)})}); auto* s = Structure("S", utils::Vector{Member("a", ty.f32(), utils::Vector{Location(0_a)})});
Func("foo", utils::Vector{Param("param", ty.Of(s))}, ty.void_(), utils::Empty, utils::Empty); Func("foo", utils::Vector{Param("param", ty.Of(s))}, ty.void_(), utils::Empty, utils::Empty);
@ -51,7 +51,7 @@ TEST_F(ResolverPipelineStageUseTest, StructUsedAsNonEntryPointParam) {
} }
TEST_F(ResolverPipelineStageUseTest, StructUsedAsNonEntryPointReturnType) { TEST_F(ResolverPipelineStageUseTest, StructUsedAsNonEntryPointReturnType) {
auto* s = Structure("S", utils::Vector{Member("a", ty.f32(), utils::Vector{Location(0)})}); auto* s = Structure("S", utils::Vector{Member("a", ty.f32(), utils::Vector{Location(0_a)})});
Func("foo", utils::Empty, ty.Of(s), utils::Vector{Return(Construct(ty.Of(s), Expr(0_f)))}, Func("foo", utils::Empty, ty.Of(s), utils::Vector{Return(Construct(ty.Of(s), Expr(0_f)))},
utils::Empty); utils::Empty);
@ -64,7 +64,7 @@ TEST_F(ResolverPipelineStageUseTest, StructUsedAsNonEntryPointReturnType) {
} }
TEST_F(ResolverPipelineStageUseTest, StructUsedAsVertexShaderParam) { TEST_F(ResolverPipelineStageUseTest, StructUsedAsVertexShaderParam) {
auto* s = Structure("S", utils::Vector{Member("a", ty.f32(), utils::Vector{Location(0)})}); auto* s = Structure("S", utils::Vector{Member("a", ty.f32(), utils::Vector{Location(0_a)})});
Func("main", utils::Vector{Param("param", ty.Of(s))}, ty.vec4<f32>(), Func("main", utils::Vector{Param("param", ty.Of(s))}, ty.vec4<f32>(),
utils::Vector{Return(Construct(ty.vec4<f32>()))}, utils::Vector{Return(Construct(ty.vec4<f32>()))},
@ -96,7 +96,7 @@ TEST_F(ResolverPipelineStageUseTest, StructUsedAsVertexShaderReturnType) {
} }
TEST_F(ResolverPipelineStageUseTest, StructUsedAsFragmentShaderParam) { TEST_F(ResolverPipelineStageUseTest, StructUsedAsFragmentShaderParam) {
auto* s = Structure("S", utils::Vector{Member("a", ty.f32(), utils::Vector{Location(0)})}); auto* s = Structure("S", utils::Vector{Member("a", ty.f32(), utils::Vector{Location(0_a)})});
Func("main", utils::Vector{Param("param", ty.Of(s))}, ty.void_(), utils::Empty, Func("main", utils::Vector{Param("param", ty.Of(s))}, ty.void_(), utils::Empty,
utils::Vector{Stage(ast::PipelineStage::kFragment)}); utils::Vector{Stage(ast::PipelineStage::kFragment)});
@ -110,7 +110,7 @@ TEST_F(ResolverPipelineStageUseTest, StructUsedAsFragmentShaderParam) {
} }
TEST_F(ResolverPipelineStageUseTest, StructUsedAsFragmentShaderReturnType) { TEST_F(ResolverPipelineStageUseTest, StructUsedAsFragmentShaderReturnType) {
auto* s = Structure("S", utils::Vector{Member("a", ty.f32(), utils::Vector{Location(0)})}); auto* s = Structure("S", utils::Vector{Member("a", ty.f32(), utils::Vector{Location(0_a)})});
Func("main", utils::Empty, ty.Of(s), utils::Vector{Return(Construct(ty.Of(s), Expr(0_f)))}, Func("main", utils::Empty, ty.Of(s), utils::Vector{Return(Construct(ty.Of(s), Expr(0_f)))},
utils::Vector{Stage(ast::PipelineStage::kFragment)}); utils::Vector{Stage(ast::PipelineStage::kFragment)});
@ -160,7 +160,7 @@ TEST_F(ResolverPipelineStageUseTest, StructUsedMultipleStages) {
} }
TEST_F(ResolverPipelineStageUseTest, StructUsedAsShaderParamViaAlias) { TEST_F(ResolverPipelineStageUseTest, StructUsedAsShaderParamViaAlias) {
auto* s = Structure("S", utils::Vector{Member("a", ty.f32(), utils::Vector{Location(0)})}); auto* s = Structure("S", utils::Vector{Member("a", ty.f32(), utils::Vector{Location(0_a)})});
auto* s_alias = Alias("S_alias", ty.Of(s)); auto* s_alias = Alias("S_alias", ty.Of(s));
Func("main", utils::Vector{Param("param", ty.Of(s_alias))}, ty.void_(), utils::Empty, Func("main", utils::Vector{Param("param", ty.Of(s_alias))}, ty.void_(), utils::Empty,
@ -175,7 +175,7 @@ TEST_F(ResolverPipelineStageUseTest, StructUsedAsShaderParamViaAlias) {
} }
TEST_F(ResolverPipelineStageUseTest, StructUsedAsShaderParamLocationSet) { TEST_F(ResolverPipelineStageUseTest, StructUsedAsShaderParamLocationSet) {
auto* s = Structure("S", utils::Vector{Member("a", ty.f32(), utils::Vector{Location(3)})}); auto* s = Structure("S", utils::Vector{Member("a", ty.f32(), utils::Vector{Location(3_a)})});
Func("main", utils::Vector{Param("param", ty.Of(s))}, ty.void_(), utils::Empty, Func("main", utils::Vector{Param("param", ty.Of(s))}, ty.void_(), utils::Empty,
utils::Vector{Stage(ast::PipelineStage::kFragment)}); utils::Vector{Stage(ast::PipelineStage::kFragment)});
@ -189,7 +189,7 @@ TEST_F(ResolverPipelineStageUseTest, StructUsedAsShaderParamLocationSet) {
} }
TEST_F(ResolverPipelineStageUseTest, StructUsedAsShaderReturnTypeViaAlias) { TEST_F(ResolverPipelineStageUseTest, StructUsedAsShaderReturnTypeViaAlias) {
auto* s = Structure("S", utils::Vector{Member("a", ty.f32(), utils::Vector{Location(0)})}); auto* s = Structure("S", utils::Vector{Member("a", ty.f32(), utils::Vector{Location(0_a)})});
auto* s_alias = Alias("S_alias", ty.Of(s)); auto* s_alias = Alias("S_alias", ty.Of(s));
Func("main", utils::Empty, ty.Of(s_alias), Func("main", utils::Empty, ty.Of(s_alias),
@ -205,7 +205,7 @@ TEST_F(ResolverPipelineStageUseTest, StructUsedAsShaderReturnTypeViaAlias) {
} }
TEST_F(ResolverPipelineStageUseTest, StructUsedAsShaderReturnTypeLocationSet) { TEST_F(ResolverPipelineStageUseTest, StructUsedAsShaderReturnTypeLocationSet) {
auto* s = Structure("S", utils::Vector{Member("a", ty.f32(), utils::Vector{Location(3)})}); auto* s = Structure("S", utils::Vector{Member("a", ty.f32(), utils::Vector{Location(3_a)})});
Func("main", utils::Empty, ty.Of(s), utils::Vector{Return(Construct(ty.Of(s), Expr(0_f)))}, Func("main", utils::Empty, ty.Of(s), utils::Vector{Return(Construct(ty.Of(s), Expr(0_f)))},
utils::Vector{Stage(ast::PipelineStage::kFragment)}); utils::Vector{Stage(ast::PipelineStage::kFragment)});

View File

@ -121,12 +121,13 @@ bool IsValidStorageTextureTexelFormat(ast::TexelFormat format) {
} }
// Helper to stringify a pipeline IO attribute. // Helper to stringify a pipeline IO attribute.
std::string attr_to_str(const ast::Attribute* attr) { std::string attr_to_str(const ast::Attribute* attr,
std::optional<uint32_t> location = std::nullopt) {
std::stringstream str; std::stringstream str;
if (auto* builtin = attr->As<ast::BuiltinAttribute>()) { if (auto* builtin = attr->As<ast::BuiltinAttribute>()) {
str << "builtin(" << builtin->builtin << ")"; str << "builtin(" << builtin->builtin << ")";
} else if (auto* location = attr->As<ast::LocationAttribute>()) { } else if (attr->Is<ast::LocationAttribute>()) {
str << "location(" << location->value << ")"; str << "location(" << location.value() << ")";
} }
return str.str(); return str.str();
} }
@ -1123,7 +1124,8 @@ bool Validator::EntryPoint(const sem::Function* func, ast::PipelineStage stage)
auto validate_entry_point_attributes_inner = [&](utils::VectorRef<const ast::Attribute*> attrs, auto validate_entry_point_attributes_inner = [&](utils::VectorRef<const ast::Attribute*> attrs,
const sem::Type* ty, Source source, const sem::Type* ty, Source source,
ParamOrRetType param_or_ret, ParamOrRetType param_or_ret,
bool is_struct_member) { bool is_struct_member,
std::optional<uint32_t> location) {
// Temporally forbid using f16 types in entry point IO. // Temporally forbid using f16 types in entry point IO.
// TODO(tint:1473, tint:1502): Remove this error after f16 is supported in entry point // TODO(tint:1473, tint:1502): Remove this error after f16 is supported in entry point
// IO. // IO.
@ -1143,7 +1145,7 @@ bool Validator::EntryPoint(const sem::Function* func, ast::PipelineStage stage)
if (auto* builtin = attr->As<ast::BuiltinAttribute>()) { if (auto* builtin = attr->As<ast::BuiltinAttribute>()) {
if (pipeline_io_attribute) { if (pipeline_io_attribute) {
AddError("multiple entry point IO attributes", attr->source); AddError("multiple entry point IO attributes", attr->source);
AddNote("previously consumed " + attr_to_str(pipeline_io_attribute), AddNote("previously consumed " + attr_to_str(pipeline_io_attribute, location),
pipeline_io_attribute->source); pipeline_io_attribute->source);
return false; return false;
} }
@ -1162,7 +1164,7 @@ bool Validator::EntryPoint(const sem::Function* func, ast::PipelineStage stage)
return false; return false;
} }
builtins.emplace(builtin->builtin); builtins.emplace(builtin->builtin);
} else if (auto* location = attr->As<ast::LocationAttribute>()) { } else if (auto* loc_attr = attr->As<ast::LocationAttribute>()) {
if (pipeline_io_attribute) { if (pipeline_io_attribute) {
AddError("multiple entry point IO attributes", attr->source); AddError("multiple entry point IO attributes", attr->source);
AddNote("previously consumed " + attr_to_str(pipeline_io_attribute), AddNote("previously consumed " + attr_to_str(pipeline_io_attribute),
@ -1173,7 +1175,13 @@ bool Validator::EntryPoint(const sem::Function* func, ast::PipelineStage stage)
bool is_input = param_or_ret == ParamOrRetType::kParameter; bool is_input = param_or_ret == ParamOrRetType::kParameter;
if (!LocationAttribute(location, ty, locations, stage, source, is_input)) { if (!location.has_value()) {
TINT_ICE(Resolver, diagnostics_) << "Location has no value";
return false;
}
if (!LocationAttribute(loc_attr, location.value(), ty, locations, stage, source,
is_input)) {
return false; return false;
} }
} else if (auto* interpolate = attr->As<ast::InterpolateAttribute>()) { } else if (auto* interpolate = attr->As<ast::InterpolateAttribute>()) {
@ -1266,9 +1274,10 @@ bool Validator::EntryPoint(const sem::Function* func, ast::PipelineStage stage)
// Outer lambda for validating the entry point attributes for a type. // Outer lambda for validating the entry point attributes for a type.
auto validate_entry_point_attributes = [&](utils::VectorRef<const ast::Attribute*> attrs, auto validate_entry_point_attributes = [&](utils::VectorRef<const ast::Attribute*> attrs,
const sem::Type* ty, Source source, const sem::Type* ty, Source source,
ParamOrRetType param_or_ret) { ParamOrRetType param_or_ret,
std::optional<uint32_t> location) {
if (!validate_entry_point_attributes_inner(attrs, ty, source, param_or_ret, if (!validate_entry_point_attributes_inner(attrs, ty, source, param_or_ret,
/*is_struct_member*/ false)) { /*is_struct_member*/ false, location)) {
return false; return false;
} }
@ -1277,7 +1286,7 @@ bool Validator::EntryPoint(const sem::Function* func, ast::PipelineStage stage)
if (!validate_entry_point_attributes_inner( if (!validate_entry_point_attributes_inner(
member->Declaration()->attributes, member->Type(), member->Declaration()->attributes, member->Type(),
member->Declaration()->source, param_or_ret, member->Declaration()->source, param_or_ret,
/*is_struct_member*/ true)) { /*is_struct_member*/ true, member->Location())) {
AddNote("while analysing entry point '" + symbols_.NameFor(decl->symbol) + "'", AddNote("while analysing entry point '" + symbols_.NameFor(decl->symbol) + "'",
decl->source); decl->source);
return false; return false;
@ -1291,7 +1300,8 @@ bool Validator::EntryPoint(const sem::Function* func, ast::PipelineStage stage)
for (auto* param : func->Parameters()) { for (auto* param : func->Parameters()) {
auto* param_decl = param->Declaration(); auto* param_decl = param->Declaration();
if (!validate_entry_point_attributes(param_decl->attributes, param->Type(), if (!validate_entry_point_attributes(param_decl->attributes, param->Type(),
param_decl->source, ParamOrRetType::kParameter)) { param_decl->source, ParamOrRetType::kParameter,
param->Location())) {
return false; return false;
} }
} }
@ -1304,7 +1314,8 @@ bool Validator::EntryPoint(const sem::Function* func, ast::PipelineStage stage)
if (!func->ReturnType()->Is<sem::Void>()) { if (!func->ReturnType()->Is<sem::Void>()) {
if (!validate_entry_point_attributes(decl->return_type_attributes, func->ReturnType(), if (!validate_entry_point_attributes(decl->return_type_attributes, func->ReturnType(),
decl->source, ParamOrRetType::kReturnType)) { decl->source, ParamOrRetType::kReturnType,
func->ReturnLocation())) {
return false; return false;
} }
} }
@ -2177,8 +2188,9 @@ bool Validator::Structure(const sem::Struct* str, ast::PipelineStage stage) cons
invariant_attribute = invariant; invariant_attribute = invariant;
} else if (auto* location = attr->As<ast::LocationAttribute>()) { } else if (auto* location = attr->As<ast::LocationAttribute>()) {
has_location = true; has_location = true;
if (!LocationAttribute(location, member->Type(), locations, stage, TINT_ASSERT(Resolver, member->Location().has_value());
member->Declaration()->source)) { if (!LocationAttribute(location, member->Location().value(), member->Type(),
locations, stage, member->Declaration()->source)) {
return false; return false;
} }
} else if (auto* builtin = attr->As<ast::BuiltinAttribute>()) { } else if (auto* builtin = attr->As<ast::BuiltinAttribute>()) {
@ -2220,7 +2232,8 @@ bool Validator::Structure(const sem::Struct* str, ast::PipelineStage stage) cons
return true; return true;
} }
bool Validator::LocationAttribute(const ast::LocationAttribute* location, bool Validator::LocationAttribute(const ast::LocationAttribute* loc_attr,
uint32_t location,
const sem::Type* type, const sem::Type* type,
std::unordered_set<uint32_t>& locations, std::unordered_set<uint32_t>& locations,
ast::PipelineStage stage, ast::PipelineStage stage,
@ -2228,7 +2241,7 @@ bool Validator::LocationAttribute(const ast::LocationAttribute* location,
const bool is_input) const { const bool is_input) const {
std::string inputs_or_output = is_input ? "inputs" : "output"; std::string inputs_or_output = is_input ? "inputs" : "output";
if (stage == ast::PipelineStage::kCompute) { if (stage == ast::PipelineStage::kCompute) {
AddError("attribute is not valid for compute shader " + inputs_or_output, location->source); AddError("attribute is not valid for compute shader " + inputs_or_output, loc_attr->source);
return false; return false;
} }
@ -2239,15 +2252,16 @@ bool Validator::LocationAttribute(const ast::LocationAttribute* location,
AddNote( AddNote(
"'location' attribute must only be applied to declarations of " "'location' attribute must only be applied to declarations of "
"numeric scalar or numeric vector type", "numeric scalar or numeric vector type",
location->source); loc_attr->source);
return false; return false;
} }
if (locations.count(location->value)) { if (locations.count(location)) {
AddError(attr_to_str(location) + " attribute appears multiple times", location->source); AddError(attr_to_str(loc_attr, location) + " attribute appears multiple times",
loc_attr->source);
return false; return false;
} }
locations.emplace(location->value); locations.emplace(location);
return true; return true;
} }

View File

@ -273,14 +273,16 @@ class Validator {
bool LocalVariable(const sem::Variable* v) const; bool LocalVariable(const sem::Variable* v) const;
/// Validates a location attribute /// Validates a location attribute
/// @param location the location attribute to validate /// @param loc_attr the location attribute to validate
/// @param location the location value
/// @param type the variable type /// @param type the variable type
/// @param locations the set of locations in the module /// @param locations the set of locations in the module
/// @param stage the current pipeline stage /// @param stage the current pipeline stage
/// @param source the source of the attribute /// @param source the source of the attribute
/// @param is_input true if this is an input variable /// @param is_input true if this is an input variable
/// @returns true on success, false otherwise. /// @returns true on success, false otherwise.
bool LocationAttribute(const ast::LocationAttribute* location, bool LocationAttribute(const ast::LocationAttribute* loc_attr,
uint32_t location,
const sem::Type* type, const sem::Type* type,
std::unordered_set<uint32_t>& locations, std::unordered_set<uint32_t>& locations,
ast::PipelineStage stage, ast::PipelineStage stage,

View File

@ -19,6 +19,7 @@
namespace tint::sem { namespace tint::sem {
namespace { namespace {
using namespace tint::number_suffixes; // NOLINT
using StructTest = TestHelper; using StructTest = TestHelper;
TEST_F(StructTest, Creation) { TEST_F(StructTest, Creation) {
@ -107,7 +108,7 @@ TEST_F(StructTest, Layout) {
TEST_F(StructTest, Location) { TEST_F(StructTest, Location) {
auto* st = Structure("st", utils::Vector{ auto* st = Structure("st", utils::Vector{
Member("a", ty.i32(), utils::Vector{Location(1u)}), Member("a", ty.i32(), utils::Vector{Location(1_u)}),
Member("b", ty.u32()), Member("b", ty.u32()),
}); });

View File

@ -37,21 +37,32 @@ CanonicalizeEntryPointIO::~CanonicalizeEntryPointIO() = default;
namespace { namespace {
// Comparison function used to reorder struct members such that all members with /// Info for a struct member
// location attributes appear first (ordered by location slot), followed by struct MemberInfo {
// those with builtin attributes. /// The struct member item
bool StructMemberComparator(const ast::StructMember* a, const ast::StructMember* b) { const ast::StructMember* member;
auto* a_loc = ast::GetAttribute<ast::LocationAttribute>(a->attributes); /// The struct member location if provided
auto* b_loc = ast::GetAttribute<ast::LocationAttribute>(b->attributes); std::optional<uint32_t> location;
auto* a_blt = ast::GetAttribute<ast::BuiltinAttribute>(a->attributes); };
auto* b_blt = ast::GetAttribute<ast::BuiltinAttribute>(b->attributes);
/// Comparison function used to reorder struct members such that all members with
/// location attributes appear first (ordered by location slot), followed by
/// those with builtin attributes.
/// @param a a struct member
/// @param b another struct member
/// @returns true if a comes before b
bool StructMemberComparator(const MemberInfo& a, const MemberInfo& b) {
auto* a_loc = ast::GetAttribute<ast::LocationAttribute>(a.member->attributes);
auto* b_loc = ast::GetAttribute<ast::LocationAttribute>(b.member->attributes);
auto* a_blt = ast::GetAttribute<ast::BuiltinAttribute>(a.member->attributes);
auto* b_blt = ast::GetAttribute<ast::BuiltinAttribute>(b.member->attributes);
if (a_loc) { if (a_loc) {
if (!b_loc) { if (!b_loc) {
// `a` has location attribute and `b` does not: `a` goes first. // `a` has location attribute and `b` does not: `a` goes first.
return true; return true;
} }
// Both have location attributes: smallest goes first. // Both have location attributes: smallest goes first.
return a_loc->value < b_loc->value; return a.location < b.location;
} else { } else {
if (b_loc) { if (b_loc) {
// `b` has location attribute and `a` does not: `b` goes first. // `b` has location attribute and `a` does not: `b` goes first.
@ -88,6 +99,8 @@ struct CanonicalizeEntryPointIO::State {
utils::Vector<const ast::Attribute*, 2> attributes; utils::Vector<const ast::Attribute*, 2> attributes;
/// The value itself. /// The value itself.
const ast::Expression* value; const ast::Expression* value;
/// The output location.
std::optional<uint32_t> location;
}; };
/// The clone context. /// The clone context.
@ -101,14 +114,15 @@ struct CanonicalizeEntryPointIO::State {
/// The new entry point wrapper function's parameters. /// The new entry point wrapper function's parameters.
utils::Vector<const ast::Parameter*, 8> wrapper_ep_parameters; utils::Vector<const ast::Parameter*, 8> wrapper_ep_parameters;
/// The members of the wrapper function's struct parameter. /// The members of the wrapper function's struct parameter.
utils::Vector<const ast::StructMember*, 8> wrapper_struct_param_members; utils::Vector<MemberInfo, 8> wrapper_struct_param_members;
/// The name of the wrapper function's struct parameter. /// The name of the wrapper function's struct parameter.
Symbol wrapper_struct_param_name; Symbol wrapper_struct_param_name;
/// The parameters that will be passed to the original function. /// The parameters that will be passed to the original function.
utils::Vector<const ast::Expression*, 8> inner_call_parameters; utils::Vector<const ast::Expression*, 8> inner_call_parameters;
/// The members of the wrapper function's struct return type. /// The members of the wrapper function's struct return type.
utils::Vector<const ast::StructMember*, 8> wrapper_struct_output_members; utils::Vector<MemberInfo, 8> wrapper_struct_output_members;
/// The wrapper function output values. /// The wrapper function output values.
utils::Vector<OutputValue, 8> wrapper_output_values; utils::Vector<OutputValue, 8> wrapper_output_values;
/// The body of the wrapper function. /// The body of the wrapper function.
@ -153,10 +167,12 @@ struct CanonicalizeEntryPointIO::State {
/// Add a shader input to the entry point. /// Add a shader input to the entry point.
/// @param name the name of the shader input /// @param name the name of the shader input
/// @param type the type of the shader input /// @param type the type of the shader input
/// @param location the location if provided
/// @param attributes the attributes to apply to the shader input /// @param attributes the attributes to apply to the shader input
/// @returns an expression which evaluates to the value of the shader input /// @returns an expression which evaluates to the value of the shader input
const ast::Expression* AddInput(std::string name, const ast::Expression* AddInput(std::string name,
const sem::Type* type, const sem::Type* type,
std::optional<uint32_t> location,
utils::Vector<const ast::Attribute*, 8> attributes) { utils::Vector<const ast::Attribute*, 8> attributes) {
auto* ast_type = CreateASTTypeFor(ctx, type); auto* ast_type = CreateASTTypeFor(ctx, type);
if (cfg.shader_style == ShaderStyle::kSpirv || cfg.shader_style == ShaderStyle::kGlsl) { if (cfg.shader_style == ShaderStyle::kSpirv || cfg.shader_style == ShaderStyle::kGlsl) {
@ -214,7 +230,7 @@ struct CanonicalizeEntryPointIO::State {
Symbol symbol = input_names.emplace(name).second ? ctx.dst->Symbols().Register(name) Symbol symbol = input_names.emplace(name).second ? ctx.dst->Symbols().Register(name)
: ctx.dst->Symbols().New(name); : ctx.dst->Symbols().New(name);
wrapper_struct_param_members.Push( wrapper_struct_param_members.Push(
ctx.dst->Member(symbol, ast_type, std::move(attributes))); {ctx.dst->Member(symbol, ast_type, std::move(attributes)), location});
return ctx.dst->MemberAccessor(InputStructSymbol(), symbol); return ctx.dst->MemberAccessor(InputStructSymbol(), symbol);
} }
} }
@ -222,10 +238,12 @@ struct CanonicalizeEntryPointIO::State {
/// Add a shader output to the entry point. /// Add a shader output to the entry point.
/// @param name the name of the shader output /// @param name the name of the shader output
/// @param type the type of the shader output /// @param type the type of the shader output
/// @param location the location if provided
/// @param attributes the attributes to apply to the shader output /// @param attributes the attributes to apply to the shader output
/// @param value the value of the shader output /// @param value the value of the shader output
void AddOutput(std::string name, void AddOutput(std::string name,
const sem::Type* type, const sem::Type* type,
std::optional<uint32_t> location,
utils::Vector<const ast::Attribute*, 8> attributes, utils::Vector<const ast::Attribute*, 8> attributes,
const ast::Expression* value) { const ast::Expression* value) {
// Vulkan requires that integer user-defined vertex outputs are always decorated with // Vulkan requires that integer user-defined vertex outputs are always decorated with
@ -256,6 +274,7 @@ struct CanonicalizeEntryPointIO::State {
output.type = CreateASTTypeFor(ctx, type); output.type = CreateASTTypeFor(ctx, type);
output.attributes = std::move(attributes); output.attributes = std::move(attributes);
output.value = value; output.value = value;
output.location = location;
wrapper_output_values.Push(output); wrapper_output_values.Push(output);
} }
@ -280,7 +299,7 @@ struct CanonicalizeEntryPointIO::State {
} }
auto name = ctx.src->Symbols().NameFor(param->Declaration()->symbol); auto name = ctx.src->Symbols().NameFor(param->Declaration()->symbol);
auto* input_expr = AddInput(name, param->Type(), std::move(attributes)); auto* input_expr = AddInput(name, param->Type(), param->Location(), std::move(attributes));
inner_call_parameters.Push(input_expr); inner_call_parameters.Push(input_expr);
} }
@ -308,7 +327,8 @@ struct CanonicalizeEntryPointIO::State {
auto name = ctx.src->Symbols().NameFor(member_ast->symbol); auto name = ctx.src->Symbols().NameFor(member_ast->symbol);
auto attributes = CloneShaderIOAttributes(member_ast->attributes, do_interpolate); auto attributes = CloneShaderIOAttributes(member_ast->attributes, do_interpolate);
auto* input_expr = AddInput(name, member->Type(), std::move(attributes)); auto* input_expr =
AddInput(name, member->Type(), member->Location(), std::move(attributes));
inner_struct_values.Push(input_expr); inner_struct_values.Push(input_expr);
} }
@ -337,7 +357,7 @@ struct CanonicalizeEntryPointIO::State {
auto attributes = CloneShaderIOAttributes(member_ast->attributes, do_interpolate); auto attributes = CloneShaderIOAttributes(member_ast->attributes, do_interpolate);
// Extract the original structure member. // Extract the original structure member.
AddOutput(name, member->Type(), std::move(attributes), AddOutput(name, member->Type(), member->Location(), std::move(attributes),
ctx.dst->MemberAccessor(original_result, name)); ctx.dst->MemberAccessor(original_result, name));
} }
} else if (!inner_ret_type->Is<sem::Void>()) { } else if (!inner_ret_type->Is<sem::Void>()) {
@ -345,8 +365,8 @@ struct CanonicalizeEntryPointIO::State {
CloneShaderIOAttributes(func_ast->return_type_attributes, do_interpolate); CloneShaderIOAttributes(func_ast->return_type_attributes, do_interpolate);
// Propagate the non-struct return value as is. // Propagate the non-struct return value as is.
AddOutput("value", func_sem->ReturnType(), std::move(attributes), AddOutput("value", func_sem->ReturnType(), func_sem->ReturnLocation(),
ctx.dst->Expr(original_result)); std::move(attributes), ctx.dst->Expr(original_result));
} }
} }
@ -365,7 +385,7 @@ struct CanonicalizeEntryPointIO::State {
// No existing sample mask builtin was found, so create a new output value // No existing sample mask builtin was found, so create a new output value
// using the fixed sample mask. // using the fixed sample mask.
AddOutput("fixed_sample_mask", ctx.dst->create<sem::U32>(), AddOutput("fixed_sample_mask", ctx.dst->create<sem::U32>(), std::nullopt,
{ctx.dst->Builtin(ast::BuiltinValue::kSampleMask)}, {ctx.dst->Builtin(ast::BuiltinValue::kSampleMask)},
ctx.dst->Expr(u32(cfg.fixed_sample_mask))); ctx.dst->Expr(u32(cfg.fixed_sample_mask)));
} }
@ -373,7 +393,7 @@ struct CanonicalizeEntryPointIO::State {
/// Add a point size builtin to the wrapper function output. /// Add a point size builtin to the wrapper function output.
void AddVertexPointSize() { void AddVertexPointSize() {
// Create a new output value and assign it a literal 1.0 value. // Create a new output value and assign it a literal 1.0 value.
AddOutput("vertex_point_size", ctx.dst->create<sem::F32>(), AddOutput("vertex_point_size", ctx.dst->create<sem::F32>(), std::nullopt,
{ctx.dst->Builtin(ast::BuiltinValue::kPointSize)}, ctx.dst->Expr(1_f)); {ctx.dst->Builtin(ast::BuiltinValue::kPointSize)}, ctx.dst->Expr(1_f));
} }
@ -392,10 +412,14 @@ struct CanonicalizeEntryPointIO::State {
std::sort(wrapper_struct_param_members.begin(), wrapper_struct_param_members.end(), std::sort(wrapper_struct_param_members.begin(), wrapper_struct_param_members.end(),
StructMemberComparator); StructMemberComparator);
utils::Vector<const ast::StructMember*, 8> members;
for (auto& mem : wrapper_struct_param_members) {
members.Push(mem.member);
}
// Create the new struct type. // Create the new struct type.
auto struct_name = ctx.dst->Sym(); auto struct_name = ctx.dst->Sym();
auto* in_struct = auto* in_struct = ctx.dst->create<ast::Struct>(struct_name, members, utils::Empty);
ctx.dst->create<ast::Struct>(struct_name, wrapper_struct_param_members, utils::Empty);
ctx.InsertBefore(ctx.src->AST().GlobalDeclarations(), func_ast, in_struct); ctx.InsertBefore(ctx.src->AST().GlobalDeclarations(), func_ast, in_struct);
// Create a new function parameter using this struct type. // Create a new function parameter using this struct type.
@ -423,7 +447,8 @@ struct CanonicalizeEntryPointIO::State {
member_names.insert(ctx.dst->Symbols().NameFor(name)); member_names.insert(ctx.dst->Symbols().NameFor(name));
wrapper_struct_output_members.Push( wrapper_struct_output_members.Push(
ctx.dst->Member(name, outval.type, std::move(outval.attributes))); {ctx.dst->Member(name, outval.type, std::move(outval.attributes)),
outval.location});
assignments.Push( assignments.Push(
ctx.dst->Assign(ctx.dst->MemberAccessor(wrapper_result, name), outval.value)); ctx.dst->Assign(ctx.dst->MemberAccessor(wrapper_result, name), outval.value));
} }
@ -432,9 +457,13 @@ struct CanonicalizeEntryPointIO::State {
std::sort(wrapper_struct_output_members.begin(), wrapper_struct_output_members.end(), std::sort(wrapper_struct_output_members.begin(), wrapper_struct_output_members.end(),
StructMemberComparator); StructMemberComparator);
utils::Vector<const ast::StructMember*, 8> members;
for (auto& mem : wrapper_struct_output_members) {
members.Push(mem.member);
}
// Create the new struct type. // Create the new struct type.
auto* out_struct = ctx.dst->create<ast::Struct>( auto* out_struct = ctx.dst->create<ast::Struct>(ctx.dst->Sym(), members, utils::Empty);
ctx.dst->Sym(), wrapper_struct_output_members, utils::Empty);
ctx.InsertBefore(ctx.src->AST().GlobalDeclarations(), func_ast, out_struct); ctx.InsertBefore(ctx.src->AST().GlobalDeclarations(), func_ast, out_struct);
// Create the output struct object, assign its members, and return it. // Create the output struct object, assign its members, and return it.

View File

@ -692,7 +692,7 @@ struct State {
/// @param func the entry point function /// @param func the entry point function
/// @param param the parameter to process /// @param param the parameter to process
void ProcessNonStructParameter(const ast::Function* func, const ast::Parameter* param) { void ProcessNonStructParameter(const ast::Function* func, const ast::Parameter* param) {
if (auto* location = ast::GetAttribute<ast::LocationAttribute>(param->attributes)) { if (ast::HasAttribute<ast::LocationAttribute>(param->attributes)) {
// Create a function-scope variable to replace the parameter. // Create a function-scope variable to replace the parameter.
auto func_var_sym = ctx.Clone(param->symbol); auto func_var_sym = ctx.Clone(param->symbol);
auto* func_var_type = ctx.Clone(param->type); auto* func_var_type = ctx.Clone(param->type);
@ -701,8 +701,15 @@ struct State {
// Capture mapping from location to the new variable. // Capture mapping from location to the new variable.
LocationInfo info; LocationInfo info;
info.expr = [this, func_var]() { return ctx.dst->Expr(func_var); }; info.expr = [this, func_var]() { return ctx.dst->Expr(func_var); };
info.type = ctx.src->Sem().Get(param)->Type();
location_info[location->value] = info; auto* sem = ctx.src->Sem().Get<sem::Parameter>(param);
info.type = sem->Type();
if (!sem->Location().has_value()) {
TINT_ICE(Transform, ctx.dst->Diagnostics()) << "Location missing value";
return;
}
location_info[sem->Location().value()] = info;
} else if (auto* builtin = ast::GetAttribute<ast::BuiltinAttribute>(param->attributes)) { } else if (auto* builtin = ast::GetAttribute<ast::BuiltinAttribute>(param->attributes)) {
// Check for existing vertex_index and instance_index builtins. // Check for existing vertex_index and instance_index builtins.
if (builtin->builtin == ast::BuiltinValue::kVertexIndex) { if (builtin->builtin == ast::BuiltinValue::kVertexIndex) {
@ -742,12 +749,16 @@ struct State {
return ctx.dst->MemberAccessor(param_sym, member_sym); return ctx.dst->MemberAccessor(param_sym, member_sym);
}; };
if (auto* location = ast::GetAttribute<ast::LocationAttribute>(member->attributes)) { if (ast::HasAttribute<ast::LocationAttribute>(member->attributes)) {
// Capture mapping from location to struct member. // Capture mapping from location to struct member.
LocationInfo info; LocationInfo info;
info.expr = member_expr; info.expr = member_expr;
info.type = ctx.src->Sem().Get(member)->Type();
location_info[location->value] = info; auto* sem = ctx.src->Sem().Get(member);
info.type = sem->Type();
TINT_ASSERT(Transform, sem->Location().has_value());
location_info[sem->Location().value()] = info;
has_locations = true; has_locations = true;
} else if (auto* builtin = } else if (auto* builtin =
ast::GetAttribute<ast::BuiltinAttribute>(member->attributes)) { ast::GetAttribute<ast::BuiltinAttribute>(member->attributes)) {

View File

@ -1856,7 +1856,7 @@ bool GeneratorImpl::EmitGlobalVariable(const ast::Variable* global) {
return Switch( return Switch(
global, // global, //
[&](const ast::Var* var) { [&](const ast::Var* var) {
auto* sem = builder_.Sem().Get(global); auto* sem = builder_.Sem().Get<sem::GlobalVariable>(global);
switch (sem->StorageClass()) { switch (sem->StorageClass()) {
case ast::StorageClass::kUniform: case ast::StorageClass::kUniform:
return EmitUniformVariable(var, sem); return EmitUniformVariable(var, sem);
@ -2005,7 +2005,7 @@ bool GeneratorImpl::EmitWorkgroupVariable(const sem::Variable* var) {
return true; return true;
} }
bool GeneratorImpl::EmitIOVariable(const sem::Variable* var) { bool GeneratorImpl::EmitIOVariable(const sem::GlobalVariable* var) {
auto* decl = var->Declaration(); auto* decl = var->Declaration();
if (auto* b = ast::GetAttribute<ast::BuiltinAttribute>(decl->attributes)) { if (auto* b = ast::GetAttribute<ast::BuiltinAttribute>(decl->attributes)) {
@ -2018,7 +2018,7 @@ bool GeneratorImpl::EmitIOVariable(const sem::Variable* var) {
} }
auto out = line(); auto out = line();
EmitAttributes(out, decl->attributes); EmitAttributes(out, var, decl->attributes);
EmitInterpolationQualifiers(out, decl->attributes); EmitInterpolationQualifiers(out, decl->attributes);
auto name = builder_.Symbols().NameFor(decl->symbol); auto name = builder_.Symbols().NameFor(decl->symbol);
@ -2065,15 +2065,16 @@ void GeneratorImpl::EmitInterpolationQualifiers(
} }
bool GeneratorImpl::EmitAttributes(std::ostream& out, bool GeneratorImpl::EmitAttributes(std::ostream& out,
const sem::GlobalVariable* var,
utils::VectorRef<const ast::Attribute*> attributes) { utils::VectorRef<const ast::Attribute*> attributes) {
if (attributes.IsEmpty()) { if (attributes.IsEmpty()) {
return true; return true;
} }
bool first = true; bool first = true;
for (auto* attr : attributes) { for (auto* attr : attributes) {
if (auto* location = attr->As<ast::LocationAttribute>()) { if (attr->As<ast::LocationAttribute>()) {
out << (first ? "layout(" : ", "); out << (first ? "layout(" : ", ");
out << "location = " << std::to_string(location->value); out << "location = " << std::to_string(var->Location().value());
first = false; first = false;
} }
} }

View File

@ -324,7 +324,7 @@ class GeneratorImpl : public TextGenerator {
/// Handles emitting a global variable with the input or output storage class /// Handles emitting a global variable with the input or output storage class
/// @param var the global variable /// @param var the global variable
/// @returns true on success /// @returns true on success
bool EmitIOVariable(const sem::Variable* var); bool EmitIOVariable(const sem::GlobalVariable* var);
/// Handles emitting interpolation qualifiers /// Handles emitting interpolation qualifiers
/// @param out the output of the expression stream /// @param out the output of the expression stream
@ -333,9 +333,12 @@ class GeneratorImpl : public TextGenerator {
utils::VectorRef<const ast::Attribute*> attrs); utils::VectorRef<const ast::Attribute*> attrs);
/// Handles emitting attributes /// Handles emitting attributes
/// @param out the output of the expression stream /// @param out the output of the expression stream
/// @param var the global variable semantics
/// @param attrs the attributes /// @param attrs the attributes
/// @returns true if the attributes were emitted /// @returns true if the attributes were emitted
bool EmitAttributes(std::ostream& out, utils::VectorRef<const ast::Attribute*> attrs); bool EmitAttributes(std::ostream& out,
const sem::GlobalVariable* var,
utils::VectorRef<const ast::Attribute*> attrs);
/// Handles emitting the entry point function /// Handles emitting the entry point function
/// @param func the entry point /// @param func the entry point
/// @returns true if the entry point function was emitted /// @returns true if the entry point function was emitted

View File

@ -128,7 +128,7 @@ TEST_F(GlslGeneratorImplTest_Function, Emit_Attribute_EntryPoint_WithInOutVars)
// } // }
Func("frag_main", Func("frag_main",
utils::Vector{ utils::Vector{
Param("foo", ty.f32(), utils::Vector{Location(0)}), Param("foo", ty.f32(), utils::Vector{Location(0_a)}),
}, },
ty.f32(), ty.f32(),
utils::Vector{ utils::Vector{
@ -138,7 +138,7 @@ TEST_F(GlslGeneratorImplTest_Function, Emit_Attribute_EntryPoint_WithInOutVars)
Stage(ast::PipelineStage::kFragment), Stage(ast::PipelineStage::kFragment),
}, },
utils::Vector{ utils::Vector{
Location(1), Location(1_a),
}); });
GeneratorImpl& gen = SanitizeAndBuild(); GeneratorImpl& gen = SanitizeAndBuild();
@ -218,8 +218,8 @@ TEST_F(GlslGeneratorImplTest_Function, Emit_Attribute_EntryPoint_SharedStruct_Di
"Interface", "Interface",
utils::Vector{ utils::Vector{
Member("pos", ty.vec4<f32>(), utils::Vector{Builtin(ast::BuiltinValue::kPosition)}), Member("pos", ty.vec4<f32>(), utils::Vector{Builtin(ast::BuiltinValue::kPosition)}),
Member("col1", ty.f32(), utils::Vector{Location(1)}), Member("col1", ty.f32(), utils::Vector{Location(1_a)}),
Member("col2", ty.f32(), utils::Vector{Location(2)}), Member("col2", ty.f32(), utils::Vector{Location(2_a)}),
}); });
Func("vert_main", utils::Empty, ty.Of(interface_struct), Func("vert_main", utils::Empty, ty.Of(interface_struct),

View File

@ -3947,23 +3947,24 @@ bool GeneratorImpl::EmitStructType(TextBuffer* b, const sem::Struct* str) {
std::string pre, post; std::string pre, post;
if (auto* decl = mem->Declaration()) { if (auto* decl = mem->Declaration()) {
for (auto* attr : decl->attributes) { for (auto* attr : decl->attributes) {
if (auto* location = attr->As<ast::LocationAttribute>()) { if (attr->Is<ast::LocationAttribute>()) {
auto& pipeline_stage_uses = str->PipelineStageUses(); auto& pipeline_stage_uses = str->PipelineStageUses();
if (pipeline_stage_uses.size() != 1) { if (pipeline_stage_uses.size() != 1) {
TINT_ICE(Writer, diagnostics_) << "invalid entry point IO struct uses"; TINT_ICE(Writer, diagnostics_) << "invalid entry point IO struct uses";
} }
auto loc = mem->Location().value();
if (pipeline_stage_uses.count(sem::PipelineStageUsage::kVertexInput)) { if (pipeline_stage_uses.count(sem::PipelineStageUsage::kVertexInput)) {
post += " : TEXCOORD" + std::to_string(location->value); post += " : TEXCOORD" + std::to_string(loc);
} else if (pipeline_stage_uses.count( } else if (pipeline_stage_uses.count(
sem::PipelineStageUsage::kVertexOutput)) { sem::PipelineStageUsage::kVertexOutput)) {
post += " : TEXCOORD" + std::to_string(location->value); post += " : TEXCOORD" + std::to_string(loc);
} else if (pipeline_stage_uses.count( } else if (pipeline_stage_uses.count(
sem::PipelineStageUsage::kFragmentInput)) { sem::PipelineStageUsage::kFragmentInput)) {
post += " : TEXCOORD" + std::to_string(location->value); post += " : TEXCOORD" + std::to_string(loc);
} else if (pipeline_stage_uses.count( } else if (pipeline_stage_uses.count(
sem::PipelineStageUsage::kFragmentOutput)) { sem::PipelineStageUsage::kFragmentOutput)) {
post += " : SV_Target" + std::to_string(location->value); post += " : SV_Target" + std::to_string(loc);
} else { } else {
TINT_ICE(Writer, diagnostics_) << "invalid use of location attribute"; TINT_ICE(Writer, diagnostics_) << "invalid use of location attribute";
} }

View File

@ -117,7 +117,7 @@ TEST_F(HlslGeneratorImplTest_Function, Emit_Attribute_EntryPoint_WithInOutVars)
// fn frag_main(@location(0) foo : f32) -> @location(1) f32 { // fn frag_main(@location(0) foo : f32) -> @location(1) f32 {
// return foo; // return foo;
// } // }
auto* foo_in = Param("foo", ty.f32(), utils::Vector{Location(0)}); auto* foo_in = Param("foo", ty.f32(), utils::Vector{Location(0_a)});
Func("frag_main", utils::Vector{foo_in}, ty.f32(), Func("frag_main", utils::Vector{foo_in}, ty.f32(),
utils::Vector{ utils::Vector{
Return("foo"), Return("foo"),
@ -126,7 +126,7 @@ TEST_F(HlslGeneratorImplTest_Function, Emit_Attribute_EntryPoint_WithInOutVars)
Stage(ast::PipelineStage::kFragment), Stage(ast::PipelineStage::kFragment),
}, },
utils::Vector{ utils::Vector{
Location(1), Location(1_a),
}); });
GeneratorImpl& gen = SanitizeAndBuild(); GeneratorImpl& gen = SanitizeAndBuild();
@ -210,8 +210,8 @@ TEST_F(HlslGeneratorImplTest_Function, Emit_Attribute_EntryPoint_SharedStruct_Di
"Interface", "Interface",
utils::Vector{ utils::Vector{
Member("pos", ty.vec4<f32>(), utils::Vector{Builtin(ast::BuiltinValue::kPosition)}), Member("pos", ty.vec4<f32>(), utils::Vector{Builtin(ast::BuiltinValue::kPosition)}),
Member("col1", ty.f32(), utils::Vector{Location(1)}), Member("col1", ty.f32(), utils::Vector{Location(1_a)}),
Member("col2", ty.f32(), utils::Vector{Location(2)}), Member("col2", ty.f32(), utils::Vector{Location(2_a)}),
}); });
Func("vert_main", utils::Empty, ty.Of(interface_struct), Func("vert_main", utils::Empty, ty.Of(interface_struct),

View File

@ -2785,24 +2785,25 @@ bool GeneratorImpl::EmitStructType(TextBuffer* b, const sem::Struct* str) {
out << " [[" << name << "]]"; out << " [[" << name << "]]";
return true; return true;
}, },
[&](const ast::LocationAttribute* loc) { [&](const ast::LocationAttribute*) {
auto& pipeline_stage_uses = str->PipelineStageUses(); auto& pipeline_stage_uses = str->PipelineStageUses();
if (pipeline_stage_uses.size() != 1) { if (pipeline_stage_uses.size() != 1) {
TINT_ICE(Writer, diagnostics_) << "invalid entry point IO struct uses"; TINT_ICE(Writer, diagnostics_) << "invalid entry point IO struct uses";
return false; return false;
} }
uint32_t loc = mem->Location().value();
if (pipeline_stage_uses.count(sem::PipelineStageUsage::kVertexInput)) { if (pipeline_stage_uses.count(sem::PipelineStageUsage::kVertexInput)) {
out << " [[attribute(" + std::to_string(loc->value) + ")]]"; out << " [[attribute(" + std::to_string(loc) + ")]]";
} else if (pipeline_stage_uses.count( } else if (pipeline_stage_uses.count(
sem::PipelineStageUsage::kVertexOutput)) { sem::PipelineStageUsage::kVertexOutput)) {
out << " [[user(locn" + std::to_string(loc->value) + ")]]"; out << " [[user(locn" + std::to_string(loc) + ")]]";
} else if (pipeline_stage_uses.count( } else if (pipeline_stage_uses.count(
sem::PipelineStageUsage::kFragmentInput)) { sem::PipelineStageUsage::kFragmentInput)) {
out << " [[user(locn" + std::to_string(loc->value) + ")]]"; out << " [[user(locn" + std::to_string(loc) + ")]]";
} else if (pipeline_stage_uses.count( } else if (pipeline_stage_uses.count(
sem::PipelineStageUsage::kFragmentOutput)) { sem::PipelineStageUsage::kFragmentOutput)) {
out << " [[color(" + std::to_string(loc->value) + ")]]"; out << " [[color(" + std::to_string(loc) + ")]]";
} else { } else {
TINT_ICE(Writer, diagnostics_) << "invalid use of location decoration"; TINT_ICE(Writer, diagnostics_) << "invalid use of location decoration";
return false; return false;

View File

@ -91,7 +91,7 @@ TEST_F(MslGeneratorImplTest, Emit_Attribute_EntryPoint_WithInOutVars) {
// fn frag_main(@location(0) foo : f32) -> @location(1) f32 { // fn frag_main(@location(0) foo : f32) -> @location(1) f32 {
// return foo; // return foo;
// } // }
auto* foo_in = Param("foo", ty.f32(), utils::Vector{Location(0)}); auto* foo_in = Param("foo", ty.f32(), utils::Vector{Location(0_a)});
Func("frag_main", utils::Vector{foo_in}, ty.f32(), Func("frag_main", utils::Vector{foo_in}, ty.f32(),
utils::Vector{ utils::Vector{
Return("foo"), Return("foo"),
@ -100,7 +100,7 @@ TEST_F(MslGeneratorImplTest, Emit_Attribute_EntryPoint_WithInOutVars) {
Stage(ast::PipelineStage::kFragment), Stage(ast::PipelineStage::kFragment),
}, },
utils::Vector{ utils::Vector{
Location(1), Location(1_a),
}); });
GeneratorImpl& gen = SanitizeAndBuild(); GeneratorImpl& gen = SanitizeAndBuild();
@ -188,8 +188,8 @@ TEST_F(MslGeneratorImplTest, Emit_Attribute_EntryPoint_SharedStruct_DifferentSta
auto* interface_struct = Structure( auto* interface_struct = Structure(
"Interface", "Interface",
utils::Vector{ utils::Vector{
Member("col1", ty.f32(), utils::Vector{Location(1)}), Member("col1", ty.f32(), utils::Vector{Location(1_a)}),
Member("col2", ty.f32(), utils::Vector{Location(2)}), Member("col2", ty.f32(), utils::Vector{Location(2_a)}),
Member("pos", ty.vec4<f32>(), utils::Vector{Builtin(ast::BuiltinValue::kPosition)}), Member("pos", ty.vec4<f32>(), utils::Vector{Builtin(ast::BuiltinValue::kPosition)}),
}); });

View File

@ -884,9 +884,9 @@ bool Builder::GenerateGlobalVariable(const ast::Variable* v) {
U32Operand(ConvertBuiltin(builtin->builtin, sem->StorageClass()))}); U32Operand(ConvertBuiltin(builtin->builtin, sem->StorageClass()))});
return true; return true;
}, },
[&](const ast::LocationAttribute* location) { [&](const ast::LocationAttribute*) {
push_annot(spv::Op::OpDecorate, {Operand(var_id), U32Operand(SpvDecorationLocation), push_annot(spv::Op::OpDecorate, {Operand(var_id), U32Operand(SpvDecorationLocation),
Operand(location->value)}); Operand(sem->Location().value())});
return true; return true;
}, },
[&](const ast::InterpolateAttribute* interpolate) { [&](const ast::InterpolateAttribute* interpolate) {

View File

@ -48,7 +48,7 @@ TEST_F(BuilderTest, EntryPoint_Parameters) {
}); });
auto* loc1 = Param("loc1", ty.f32(), auto* loc1 = Param("loc1", ty.f32(),
utils::Vector{ utils::Vector{
Location(1u), Location(1_u),
}); });
auto* mul = Mul(Expr(MemberAccessor("coord", "x")), Expr("loc1")); auto* mul = Mul(Expr(MemberAccessor("coord", "x")), Expr("loc1"));
auto* col = Var("col", ty.f32(), mul); auto* col = Var("col", ty.f32(), mul);
@ -120,7 +120,7 @@ TEST_F(BuilderTest, EntryPoint_ReturnValue) {
// } // }
auto* loc_in = Param("loc_in", ty.u32(), auto* loc_in = Param("loc_in", ty.u32(),
utils::Vector{ utils::Vector{
Location(0), Location(0_a),
Flat(), Flat(),
}); });
auto* cond = auto* cond =
@ -134,7 +134,7 @@ TEST_F(BuilderTest, EntryPoint_ReturnValue) {
Stage(ast::PipelineStage::kFragment), Stage(ast::PipelineStage::kFragment),
}, },
utils::Vector{ utils::Vector{
Location(0), Location(0_a),
}); });
spirv::Builder& b = SanitizeAndBuild(); spirv::Builder& b = SanitizeAndBuild();
@ -211,7 +211,7 @@ TEST_F(BuilderTest, EntryPoint_SharedStruct) {
auto* interface = Structure( auto* interface = Structure(
"Interface", "Interface",
utils::Vector{ utils::Vector{
Member("value", ty.f32(), utils::Vector{Location(1u)}), Member("value", ty.f32(), utils::Vector{Location(1_u)}),
Member("pos", ty.vec4<f32>(), utils::Vector{Builtin(ast::BuiltinValue::kPosition)}), Member("pos", ty.vec4<f32>(), utils::Vector{Builtin(ast::BuiltinValue::kPosition)}),
}); });

View File

@ -756,7 +756,11 @@ bool GeneratorImpl::EmitAttributes(std::ostream& out,
return true; return true;
}, },
[&](const ast::LocationAttribute* location) { [&](const ast::LocationAttribute* location) {
out << "location(" << location->value << ")"; out << "location(";
if (!EmitExpression(out, location->value)) {
return false;
}
out << ")";
return true; return true;
}, },
[&](const ast::BuiltinAttribute* builtin) { [&](const ast::BuiltinAttribute* builtin) {

View File

@ -116,7 +116,7 @@ TEST_F(WgslGeneratorImplTest, Emit_Function_EntryPoint_Parameters) {
}); });
auto* loc1 = Param("loc1", ty.f32(), auto* loc1 = Param("loc1", ty.f32(),
utils::Vector{ utils::Vector{
Location(1u), Location(1_a),
}); });
auto* func = Func("frag_main", utils::Vector{coord, loc1}, ty.void_(), utils::Empty, auto* func = Func("frag_main", utils::Vector{coord, loc1}, ty.void_(), utils::Empty,
utils::Vector{ utils::Vector{
@ -143,7 +143,7 @@ TEST_F(WgslGeneratorImplTest, Emit_Function_EntryPoint_ReturnValue) {
Stage(ast::PipelineStage::kFragment), Stage(ast::PipelineStage::kFragment),
}, },
utils::Vector{ utils::Vector{
Location(1u), Location(1_a),
}); });
GeneratorImpl& gen = Build(); GeneratorImpl& gen = Build();

View File

@ -274,7 +274,7 @@ TEST_F(WgslGeneratorImplTest, EmitType_Struct_WithEntryPointAttributes) {
auto* s = Structure( auto* s = Structure(
"S", utils::Vector{ "S", utils::Vector{
Member("a", ty.u32(), utils::Vector{Builtin(ast::BuiltinValue::kVertexIndex)}), Member("a", ty.u32(), utils::Vector{Builtin(ast::BuiltinValue::kVertexIndex)}),
Member("b", ty.f32(), utils::Vector{Location(2u)}), Member("b", ty.f32(), utils::Vector{Location(2_a)}),
}); });
GeneratorImpl& gen = Build(); GeneratorImpl& gen = Build();