AST fuzzer: change binary operator

A mutation and mutation finder that changes the operator in a binary
expression to something type-compatible.

Fixes: tint:1085
Change-Id: I2e35d3cdfdbcc52d4dc5981b187da217fc48e462
Reviewed-on: https://dawn-review.googlesource.com/c/tint/+/84640
Reviewed-by: Ryan Harrison <rharrison@chromium.org>
Kokoro: Kokoro <noreply+kokoro@google.com>
Commit-Queue: Alastair Donaldson <afdx@google.com>
Auto-Submit: Alastair Donaldson <afdx@google.com>
This commit is contained in:
Alastair Donaldson 2022-03-25 12:31:35 +00:00 committed by Tint LUCI CQ
parent 1006b06c7d
commit 444e051faa
16 changed files with 1580 additions and 9 deletions

View File

@ -46,8 +46,12 @@ if (build_with_chromium) {
"mutation.h", "mutation.h",
"mutation_finder.cc", "mutation_finder.cc",
"mutation_finder.h", "mutation_finder.h",
"mutation_finders/change_binary_operators.cc",
"mutation_finders/change_binary_operators.h",
"mutation_finders/replace_identifiers.cc", "mutation_finders/replace_identifiers.cc",
"mutation_finders/replace_identifiers.h", "mutation_finders/replace_identifiers.h",
"mutations/change_binary_operator.cc",
"mutations/change_binary_operator.h",
"mutations/replace_identifier.cc", "mutations/replace_identifier.cc",
"mutations/replace_identifier.h", "mutations/replace_identifier.h",
"mutator.cc", "mutator.cc",

View File

@ -40,7 +40,9 @@ set(LIBTINT_AST_FUZZER_SOURCES
../random_generator_engine.h ../random_generator_engine.h
mutation.h mutation.h
mutation_finder.h mutation_finder.h
mutation_finders/change_binary_operators.h
mutation_finders/replace_identifiers.h mutation_finders/replace_identifiers.h
mutations/change_binary_operator.h
mutations/replace_identifier.h mutations/replace_identifier.h
mutator.h mutator.h
node_id_map.h node_id_map.h
@ -55,7 +57,9 @@ set(LIBTINT_AST_FUZZER_SOURCES ${LIBTINT_AST_FUZZER_SOURCES}
../random_generator_engine.cc ../random_generator_engine.cc
mutation.cc mutation.cc
mutation_finder.cc mutation_finder.cc
mutation_finders/change_binary_operators.cc
mutation_finders/replace_identifiers.cc mutation_finders/replace_identifiers.cc
mutations/change_binary_operator.cc
mutations/replace_identifier.cc mutations/replace_identifier.cc
mutator.cc mutator.cc
node_id_map.cc node_id_map.cc
@ -92,6 +96,7 @@ add_tint_ast_fuzzer(tint_ast_wgsl_writer_fuzzer)
# Add tests. # Add tests.
if (${TINT_BUILD_TESTS}) if (${TINT_BUILD_TESTS})
set(TEST_SOURCES set(TEST_SOURCES
mutations/change_binary_operator_test.cc
mutations/replace_identifier_test.cc) mutations/replace_identifier_test.cc)
add_executable(tint_ast_fuzzer_unittests ${TEST_SOURCES}) add_executable(tint_ast_fuzzer_unittests ${TEST_SOURCES})

View File

@ -16,6 +16,7 @@
#include <cassert> #include <cassert>
#include "src/tint/fuzzers/tint_ast_fuzzer/mutations/change_binary_operator.h"
#include "src/tint/fuzzers/tint_ast_fuzzer/mutations/replace_identifier.h" #include "src/tint/fuzzers/tint_ast_fuzzer/mutations/replace_identifier.h"
namespace tint { namespace tint {
@ -30,6 +31,9 @@ std::unique_ptr<Mutation> Mutation::FromMessage(
case protobufs::Mutation::kReplaceIdentifier: case protobufs::Mutation::kReplaceIdentifier:
return std::make_unique<MutationReplaceIdentifier>( return std::make_unique<MutationReplaceIdentifier>(
message.replace_identifier()); message.replace_identifier());
case protobufs::Mutation::kChangeBinaryOperator:
return std::make_unique<MutationChangeBinaryOperator>(
message.change_binary_operator());
case protobufs::Mutation::MUTATION_NOT_SET: case protobufs::Mutation::MUTATION_NOT_SET:
assert(false && "Mutation is not set"); assert(false && "Mutation is not set");
break; break;

View File

@ -0,0 +1,92 @@
// Copyright 2022 The Tint Authors.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include "src/tint/fuzzers/tint_ast_fuzzer/mutation_finders/change_binary_operators.h"
#include <memory>
#include <vector>
#include "src/tint/ast/binary_expression.h"
#include "src/tint/fuzzers/tint_ast_fuzzer/mutations/change_binary_operator.h"
namespace tint {
namespace fuzzers {
namespace ast_fuzzer {
MutationList MutationFinderChangeBinaryOperators::FindMutations(
const tint::Program& program,
NodeIdMap* node_id_map,
ProbabilityContext* probability_context) const {
MutationList result;
// Go through each binary expression in the AST and add a mutation that
// replaces its operator with some other type-compatible operator.
const std::vector<ast::BinaryOp> all_binary_operators = {
ast::BinaryOp::kAnd,
ast::BinaryOp::kOr,
ast::BinaryOp::kXor,
ast::BinaryOp::kLogicalAnd,
ast::BinaryOp::kLogicalOr,
ast::BinaryOp::kEqual,
ast::BinaryOp::kNotEqual,
ast::BinaryOp::kLessThan,
ast::BinaryOp::kGreaterThan,
ast::BinaryOp::kLessThanEqual,
ast::BinaryOp::kGreaterThanEqual,
ast::BinaryOp::kShiftLeft,
ast::BinaryOp::kShiftRight,
ast::BinaryOp::kAdd,
ast::BinaryOp::kSubtract,
ast::BinaryOp::kMultiply,
ast::BinaryOp::kDivide,
ast::BinaryOp::kModulo};
for (const auto* node : program.ASTNodes().Objects()) {
const auto* binary_expr = As<ast::BinaryExpression>(node);
if (!binary_expr) {
continue;
}
// Get vector of all operators this could be replaced with.
std::vector<ast::BinaryOp> allowed_replacements;
for (auto candidate_op : all_binary_operators) {
if (MutationChangeBinaryOperator::CanReplaceBinaryOperator(
program, *binary_expr, candidate_op)) {
allowed_replacements.push_back(candidate_op);
}
}
if (!allowed_replacements.empty()) {
// Choose an available replacement operator at random.
const ast::BinaryOp replacement =
allowed_replacements[probability_context->GetRandomIndex(
allowed_replacements)];
// Add a mutation according to the chosen replacement.
result.push_back(std::make_unique<MutationChangeBinaryOperator>(
node_id_map->GetId(binary_expr), replacement));
}
}
return result;
}
uint32_t MutationFinderChangeBinaryOperators::GetChanceOfApplyingMutation(
ProbabilityContext* probability_context) const {
return probability_context->GetChanceOfChangingBinaryOperators();
}
} // namespace ast_fuzzer
} // namespace fuzzers
} // namespace tint

View File

@ -0,0 +1,42 @@
// Copyright 2022 The Tint Authors.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#ifndef SRC_TINT_FUZZERS_TINT_AST_FUZZER_MUTATION_FINDERS_CHANGE_BINARY_OPERATORS_H_
#define SRC_TINT_FUZZERS_TINT_AST_FUZZER_MUTATION_FINDERS_CHANGE_BINARY_OPERATORS_H_
#include "src/tint/fuzzers/tint_ast_fuzzer/mutation_finder.h"
namespace tint {
namespace fuzzers {
namespace ast_fuzzer {
/// Looks for opportunities to apply `MutationChangeBinaryOperator`.
///
/// Concretely, for each binary expression in the module, tries to replace it
/// with a different, type-compatible operator.
class MutationFinderChangeBinaryOperators : public MutationFinder {
public:
MutationList FindMutations(
const tint::Program& program,
NodeIdMap* node_id_map,
ProbabilityContext* probability_context) const override;
uint32_t GetChanceOfApplyingMutation(
ProbabilityContext* probability_context) const override;
};
} // namespace ast_fuzzer
} // namespace fuzzers
} // namespace tint
#endif // SRC_TINT_FUZZERS_TINT_AST_FUZZER_MUTATION_FINDERS_CHANGE_BINARY_OPERATORS_H_

View File

@ -0,0 +1,528 @@
// Copyright 2022 The Tint Authors.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include "src/tint/fuzzers/tint_ast_fuzzer/mutations/change_binary_operator.h"
#include <utility>
#include "src/tint/sem/reference_type.h"
namespace tint {
namespace fuzzers {
namespace ast_fuzzer {
namespace {
bool IsSuitableForShift(const sem::Type* lhs_type, const sem::Type* rhs_type) {
// `a << b` requires b to be an unsigned scalar or vector, and `a` to be an
// integer scalar or vector with the same width as `b`. Similar for `a >> b`.
if (rhs_type->is_unsigned_integer_scalar()) {
return lhs_type->is_integer_scalar();
}
if (rhs_type->is_unsigned_integer_vector()) {
return lhs_type->is_unsigned_integer_vector();
}
return false;
}
bool CanReplaceAddSubtractWith(const sem::Type* lhs_type,
const sem::Type* rhs_type,
ast::BinaryOp new_operator) {
// The program is assumed to be well-typed, so this method determines when
// 'new_operator' can be used as a type-preserving replacement in an '+' or
// '-' expression.
switch (new_operator) {
case ast::BinaryOp::kAdd:
case ast::BinaryOp::kSubtract:
// '+' and '-' are fully type compatible.
return true;
case ast::BinaryOp::kAnd:
case ast::BinaryOp::kOr:
case ast::BinaryOp::kXor:
// These operators do not have a mixed vector-scalar form, and only work
// on integer types.
return lhs_type == rhs_type && lhs_type->is_integer_scalar_or_vector();
case ast::BinaryOp::kMultiply:
// '+' and '*' are largely type-compatible, but for matrices they are only
// type-compatible if the matrices are square.
return !lhs_type->is_float_matrix() || lhs_type->is_square_float_matrix();
case ast::BinaryOp::kDivide:
// '/' is not defined for matrices.
return lhs_type->is_numeric_scalar_or_vector() &&
rhs_type->is_numeric_scalar_or_vector();
case ast::BinaryOp::kModulo:
// TODO(https://crbug.com/tint/1370): once fixed, the rules should be the
// same as for divide.
if (lhs_type->is_float_vector() || rhs_type->is_float_vector()) {
return lhs_type == rhs_type;
}
return !lhs_type->is_float_matrix() && !rhs_type->is_float_matrix();
case ast::BinaryOp::kShiftLeft:
case ast::BinaryOp::kShiftRight:
return IsSuitableForShift(lhs_type, rhs_type);
default:
return false;
}
}
bool CanReplaceMultiplyWith(const sem::Type* lhs_type,
const sem::Type* rhs_type,
ast::BinaryOp new_operator) {
// The program is assumed to be well-typed, so this method determines when
// 'new_operator' can be used as a type-preserving replacement in a '*'
// expression.
switch (new_operator) {
case ast::BinaryOp::kMultiply:
return true;
case ast::BinaryOp::kAdd:
case ast::BinaryOp::kSubtract:
// '*' is type-compatible with '+' and '-' for square matrices, and for
// numeric scalars/vectors.
if (lhs_type->is_square_float_matrix() &&
rhs_type->is_square_float_matrix()) {
return true;
}
return lhs_type->is_numeric_scalar_or_vector() &&
rhs_type->is_numeric_scalar_or_vector();
case ast::BinaryOp::kAnd:
case ast::BinaryOp::kOr:
case ast::BinaryOp::kXor:
// These operators require homogeneous integer types.
return lhs_type == rhs_type && lhs_type->is_integer_scalar_or_vector();
case ast::BinaryOp::kDivide:
// '/' is not defined for matrices.
return lhs_type->is_numeric_scalar_or_vector() &&
rhs_type->is_numeric_scalar_or_vector();
case ast::BinaryOp::kModulo:
// TODO(https://crbug.com/tint/1370): once fixed, this should be the same
// as for divide
if (lhs_type->is_float_vector() || rhs_type->is_float_vector()) {
return lhs_type == rhs_type;
}
return !lhs_type->is_float_matrix() && !rhs_type->is_float_matrix();
case ast::BinaryOp::kShiftLeft:
case ast::BinaryOp::kShiftRight:
return IsSuitableForShift(lhs_type, rhs_type);
default:
return false;
}
}
bool CanReplaceDivideWith(const sem::Type* lhs_type,
const sem::Type* rhs_type,
ast::BinaryOp new_operator) {
// The program is assumed to be well-typed, so this method determines when
// 'new_operator' can be used as a type-preserving replacement in a '/'
// expression.
switch (new_operator) {
case ast::BinaryOp::kAdd:
case ast::BinaryOp::kSubtract:
case ast::BinaryOp::kMultiply:
case ast::BinaryOp::kDivide:
// These operators work in all contexts where '/' works.
return true;
case ast::BinaryOp::kModulo:
// TODO(https://crbug.com/tint/1370): this special case should not be
// required; modulo and divide should work in the same contexts.
return lhs_type->is_integer_scalar_or_vector() || lhs_type == rhs_type;
case ast::BinaryOp::kAnd:
case ast::BinaryOp::kOr:
case ast::BinaryOp::kXor:
// These operators require homogeneous integer types.
return lhs_type == rhs_type && lhs_type->is_integer_scalar_or_vector();
case ast::BinaryOp::kShiftLeft:
case ast::BinaryOp::kShiftRight:
return IsSuitableForShift(lhs_type, rhs_type);
default:
return false;
}
}
// TODO(https://crbug.com/tint/1370): once fixed, this method will be removed
// and the same method will be used to check Divide and Modulo.
bool CanReplaceModuloWith(const sem::Type* lhs_type,
const sem::Type* rhs_type,
ast::BinaryOp new_operator) {
switch (new_operator) {
case ast::BinaryOp::kAdd:
case ast::BinaryOp::kSubtract:
case ast::BinaryOp::kMultiply:
case ast::BinaryOp::kDivide:
case ast::BinaryOp::kModulo:
return true;
case ast::BinaryOp::kAnd:
case ast::BinaryOp::kOr:
case ast::BinaryOp::kXor:
return lhs_type == rhs_type && lhs_type->is_integer_scalar_or_vector();
case ast::BinaryOp::kShiftLeft:
case ast::BinaryOp::kShiftRight:
return IsSuitableForShift(lhs_type, rhs_type);
default:
return false;
}
}
bool CanReplaceLogicalAndLogicalOrWith(ast::BinaryOp new_operator) {
switch (new_operator) {
case ast::BinaryOp::kLogicalAnd:
case ast::BinaryOp::kLogicalOr:
case ast::BinaryOp::kAnd:
case ast::BinaryOp::kOr:
case ast::BinaryOp::kEqual:
case ast::BinaryOp::kNotEqual:
// These operators all work whenever '&&' and '||' work.
return true;
default:
return false;
}
}
bool CanReplaceAndOrWith(const sem::Type* lhs_type,
const sem::Type* rhs_type,
ast::BinaryOp new_operator) {
switch (new_operator) {
case ast::BinaryOp::kAnd:
case ast::BinaryOp::kOr:
// '&' and '|' work in all the same contexts.
return true;
case ast::BinaryOp::kAdd:
case ast::BinaryOp::kSubtract:
case ast::BinaryOp::kMultiply:
case ast::BinaryOp::kDivide:
case ast::BinaryOp::kModulo:
case ast::BinaryOp::kXor:
// '&' and '|' can be applied to booleans. In all other contexts,
// integer numeric operators work.
return !lhs_type->is_bool_scalar_or_vector();
case ast::BinaryOp::kShiftLeft:
case ast::BinaryOp::kShiftRight:
return IsSuitableForShift(lhs_type, rhs_type);
case ast::BinaryOp::kLogicalAnd:
case ast::BinaryOp::kLogicalOr:
// '&' and '|' can be applied to booleans, and for boolean scalar
// scalar contexts, their logical counterparts work.
return lhs_type->Is<sem::Bool>();
case ast::BinaryOp::kEqual:
case ast::BinaryOp::kNotEqual:
// '&' and '|' can be applied to booleans, and in these contexts equality
// comparison operators also work.
return lhs_type->is_bool_scalar_or_vector();
default:
return false;
}
}
bool CanReplaceXorWith(const sem::Type* lhs_type,
const sem::Type* rhs_type,
ast::BinaryOp new_operator) {
switch (new_operator) {
case ast::BinaryOp::kAdd:
case ast::BinaryOp::kSubtract:
case ast::BinaryOp::kMultiply:
case ast::BinaryOp::kDivide:
case ast::BinaryOp::kModulo:
case ast::BinaryOp::kAnd:
case ast::BinaryOp::kOr:
case ast::BinaryOp::kXor:
// '^' only works on integer types, and in any such context, all other
// integer operators also work.
return true;
case ast::BinaryOp::kShiftLeft:
case ast::BinaryOp::kShiftRight:
return IsSuitableForShift(lhs_type, rhs_type);
default:
return false;
}
}
bool CanReplaceShiftLeftShiftRightWith(const sem::Type* lhs_type,
const sem::Type* rhs_type,
ast::BinaryOp new_operator) {
switch (new_operator) {
case ast::BinaryOp::kShiftLeft:
case ast::BinaryOp::kShiftRight:
// These operators are type-compatible.
return true;
case ast::BinaryOp::kAdd:
case ast::BinaryOp::kSubtract:
case ast::BinaryOp::kMultiply:
case ast::BinaryOp::kDivide:
case ast::BinaryOp::kModulo:
case ast::BinaryOp::kAnd:
case ast::BinaryOp::kOr:
case ast::BinaryOp::kXor:
// Shift operators allow mixing of signed and unsigned arguments, but in
// the case where the arguments are homogeneous, they are type-compatible
// with other numeric operators.
return lhs_type == rhs_type;
default:
return false;
}
}
bool CanReplaceEqualNotEqualWith(const sem::Type* lhs_type,
ast::BinaryOp new_operator) {
switch (new_operator) {
case ast::BinaryOp::kEqual:
case ast::BinaryOp::kNotEqual:
// These operators are type-compatible.
return true;
case ast::BinaryOp::kLessThan:
case ast::BinaryOp::kLessThanEqual:
case ast::BinaryOp::kGreaterThan:
case ast::BinaryOp::kGreaterThanEqual:
// An equality comparison between numeric types can be changed to an
// ordered comparison.
return lhs_type->is_numeric_scalar_or_vector();
case ast::BinaryOp::kLogicalAnd:
case ast::BinaryOp::kLogicalOr:
// An equality comparison between boolean scalars can be turned into a
// logical operation.
return lhs_type->Is<sem::Bool>();
case ast::BinaryOp::kAnd:
case ast::BinaryOp::kOr:
// An equality comparison between boolean scalars or vectors can be turned
// into a component-wise non-short-circuit logical operation.
return lhs_type->is_bool_scalar_or_vector();
default:
return false;
}
}
bool CanReplaceLessThanLessThanEqualGreaterThanGreaterThanEqualWith(
ast::BinaryOp new_operator) {
switch (new_operator) {
case ast::BinaryOp::kEqual:
case ast::BinaryOp::kNotEqual:
case ast::BinaryOp::kLessThan:
case ast::BinaryOp::kLessThanEqual:
case ast::BinaryOp::kGreaterThan:
case ast::BinaryOp::kGreaterThanEqual:
// Ordered comparison operators can be interchanged, and equality
// operators can be used in their place.
return true;
default:
return false;
}
}
} // namespace
MutationChangeBinaryOperator::MutationChangeBinaryOperator(
protobufs::MutationChangeBinaryOperator message)
: message_(std::move(message)) {}
MutationChangeBinaryOperator::MutationChangeBinaryOperator(
uint32_t binary_expr_id,
ast::BinaryOp new_operator) {
message_.set_binary_expr_id(binary_expr_id);
message_.set_new_operator(static_cast<uint32_t>(new_operator));
}
bool MutationChangeBinaryOperator::CanReplaceBinaryOperator(
const Program& program,
const ast::BinaryExpression& binary_expr,
ast::BinaryOp new_operator) {
if (new_operator == binary_expr.op) {
// An operator should not be replaced with itself, as this would be a no-op.
return false;
}
// Get the types of the operators.
const auto* lhs_type = program.Sem().Get(binary_expr.lhs)->Type();
const auto* rhs_type = program.Sem().Get(binary_expr.rhs)->Type();
// If these are reference types, unwrap them to get the pointee type.
const sem::Type* lhs_basic_type =
lhs_type->Is<sem::Reference>()
? lhs_type->As<sem::Reference>()->StoreType()
: lhs_type;
const sem::Type* rhs_basic_type =
rhs_type->Is<sem::Reference>()
? rhs_type->As<sem::Reference>()->StoreType()
: rhs_type;
switch (binary_expr.op) {
case ast::BinaryOp::kAdd:
case ast::BinaryOp::kSubtract:
return CanReplaceAddSubtractWith(lhs_basic_type, rhs_basic_type,
new_operator);
case ast::BinaryOp::kMultiply:
return CanReplaceMultiplyWith(lhs_basic_type, rhs_basic_type,
new_operator);
case ast::BinaryOp::kDivide:
return CanReplaceDivideWith(lhs_basic_type, rhs_basic_type, new_operator);
case ast::BinaryOp::kModulo:
return CanReplaceModuloWith(lhs_basic_type, rhs_basic_type, new_operator);
case ast::BinaryOp::kAnd:
case ast::BinaryOp::kOr:
return CanReplaceAndOrWith(lhs_basic_type, rhs_basic_type, new_operator);
case ast::BinaryOp::kXor:
return CanReplaceXorWith(lhs_basic_type, rhs_basic_type, new_operator);
case ast::BinaryOp::kShiftLeft:
case ast::BinaryOp::kShiftRight:
return CanReplaceShiftLeftShiftRightWith(lhs_basic_type, rhs_basic_type,
new_operator);
case ast::BinaryOp::kLogicalAnd:
case ast::BinaryOp::kLogicalOr:
return CanReplaceLogicalAndLogicalOrWith(new_operator);
case ast::BinaryOp::kEqual:
case ast::BinaryOp::kNotEqual:
return CanReplaceEqualNotEqualWith(lhs_basic_type, new_operator);
case ast::BinaryOp::kLessThan:
case ast::BinaryOp::kLessThanEqual:
case ast::BinaryOp::kGreaterThan:
case ast::BinaryOp::kGreaterThanEqual:
case ast::BinaryOp::kNone:
return CanReplaceLessThanLessThanEqualGreaterThanGreaterThanEqualWith(
new_operator);
assert(false && "Unreachable");
return false;
}
}
bool MutationChangeBinaryOperator::IsApplicable(
const Program& program,
const NodeIdMap& node_id_map) const {
const auto* binary_expr_node =
As<ast::BinaryExpression>(node_id_map.GetNode(message_.binary_expr_id()));
if (binary_expr_node == nullptr) {
// Either the id does not exist, or does not correspond to a binary
// expression.
return false;
}
// Check whether the replacement is acceptable.
const auto new_operator = static_cast<ast::BinaryOp>(message_.new_operator());
return CanReplaceBinaryOperator(program, *binary_expr_node, new_operator);
}
void MutationChangeBinaryOperator::Apply(const NodeIdMap& node_id_map,
CloneContext* clone_context,
NodeIdMap* new_node_id_map) const {
// Get the node whose operator is to be replaced.
const auto* binary_expr_node =
As<ast::BinaryExpression>(node_id_map.GetNode(message_.binary_expr_id()));
// Clone the binary expression, with the appropriate new operator.
const ast::BinaryExpression* cloned_replacement;
switch (static_cast<ast::BinaryOp>(message_.new_operator())) {
case ast::BinaryOp::kAnd:
cloned_replacement =
clone_context->dst->And(clone_context->Clone(binary_expr_node->lhs),
clone_context->Clone(binary_expr_node->rhs));
break;
case ast::BinaryOp::kOr:
cloned_replacement =
clone_context->dst->Or(clone_context->Clone(binary_expr_node->lhs),
clone_context->Clone(binary_expr_node->rhs));
break;
case ast::BinaryOp::kXor:
cloned_replacement =
clone_context->dst->Xor(clone_context->Clone(binary_expr_node->lhs),
clone_context->Clone(binary_expr_node->rhs));
break;
case ast::BinaryOp::kLogicalAnd:
cloned_replacement = clone_context->dst->LogicalAnd(
clone_context->Clone(binary_expr_node->lhs),
clone_context->Clone(binary_expr_node->rhs));
break;
case ast::BinaryOp::kLogicalOr:
cloned_replacement = clone_context->dst->LogicalOr(
clone_context->Clone(binary_expr_node->lhs),
clone_context->Clone(binary_expr_node->rhs));
break;
case ast::BinaryOp::kEqual:
cloned_replacement = clone_context->dst->Equal(
clone_context->Clone(binary_expr_node->lhs),
clone_context->Clone(binary_expr_node->rhs));
break;
case ast::BinaryOp::kNotEqual:
cloned_replacement = clone_context->dst->NotEqual(
clone_context->Clone(binary_expr_node->lhs),
clone_context->Clone(binary_expr_node->rhs));
break;
case ast::BinaryOp::kLessThan:
cloned_replacement = clone_context->dst->LessThan(
clone_context->Clone(binary_expr_node->lhs),
clone_context->Clone(binary_expr_node->rhs));
break;
case ast::BinaryOp::kGreaterThan:
cloned_replacement = clone_context->dst->GreaterThan(
clone_context->Clone(binary_expr_node->lhs),
clone_context->Clone(binary_expr_node->rhs));
break;
case ast::BinaryOp::kLessThanEqual:
cloned_replacement = clone_context->dst->LessThanEqual(
clone_context->Clone(binary_expr_node->lhs),
clone_context->Clone(binary_expr_node->rhs));
break;
case ast::BinaryOp::kGreaterThanEqual:
cloned_replacement = clone_context->dst->GreaterThanEqual(
clone_context->Clone(binary_expr_node->lhs),
clone_context->Clone(binary_expr_node->rhs));
break;
case ast::BinaryOp::kShiftLeft:
cloned_replacement =
clone_context->dst->Shl(clone_context->Clone(binary_expr_node->lhs),
clone_context->Clone(binary_expr_node->rhs));
break;
case ast::BinaryOp::kShiftRight:
cloned_replacement =
clone_context->dst->Shr(clone_context->Clone(binary_expr_node->lhs),
clone_context->Clone(binary_expr_node->rhs));
break;
case ast::BinaryOp::kAdd:
cloned_replacement =
clone_context->dst->Add(clone_context->Clone(binary_expr_node->lhs),
clone_context->Clone(binary_expr_node->rhs));
break;
case ast::BinaryOp::kSubtract:
cloned_replacement =
clone_context->dst->Sub(clone_context->Clone(binary_expr_node->lhs),
clone_context->Clone(binary_expr_node->rhs));
break;
case ast::BinaryOp::kMultiply:
cloned_replacement =
clone_context->dst->Mul(clone_context->Clone(binary_expr_node->lhs),
clone_context->Clone(binary_expr_node->rhs));
break;
case ast::BinaryOp::kDivide:
cloned_replacement =
clone_context->dst->Div(clone_context->Clone(binary_expr_node->lhs),
clone_context->Clone(binary_expr_node->rhs));
break;
case ast::BinaryOp::kModulo:
cloned_replacement =
clone_context->dst->Mod(clone_context->Clone(binary_expr_node->lhs),
clone_context->Clone(binary_expr_node->rhs));
break;
case ast::BinaryOp::kNone:
cloned_replacement = nullptr;
assert(false && "Unreachable");
}
// Set things up so that the original binary expression will be replaced with
// its clone, and update the id mapping.
clone_context->Replace(binary_expr_node, cloned_replacement);
new_node_id_map->Add(cloned_replacement, message_.binary_expr_id());
}
protobufs::Mutation MutationChangeBinaryOperator::ToMessage() const {
protobufs::Mutation mutation;
*mutation.mutable_change_binary_operator() = message_;
return mutation;
}
} // namespace ast_fuzzer
} // namespace fuzzers
} // namespace tint

View File

@ -0,0 +1,85 @@
// Copyright 2022 The Tint Authors.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#ifndef SRC_TINT_FUZZERS_TINT_AST_FUZZER_MUTATIONS_CHANGE_BINARY_OPERATOR_H_
#define SRC_TINT_FUZZERS_TINT_AST_FUZZER_MUTATIONS_CHANGE_BINARY_OPERATOR_H_
#include "src/tint/fuzzers/tint_ast_fuzzer/mutation.h"
#include "src/tint/ast/binary_expression.h"
#include "src/tint/program.h"
#include "src/tint/sem/variable.h"
namespace tint {
namespace fuzzers {
namespace ast_fuzzer {
/// @see MutationChangeBinaryOperator::Apply
class MutationChangeBinaryOperator : public Mutation {
public:
/// @brief Constructs an instance of this mutation from a protobuf message.
/// @param message - protobuf message
explicit MutationChangeBinaryOperator(
protobufs::MutationChangeBinaryOperator message);
/// @brief Constructor.
/// @param binary_expr_id - the id of a binary expression.
/// @param new_operator - a new binary operator to replace the one used in the
/// expression.
MutationChangeBinaryOperator(uint32_t binary_expr_id,
ast::BinaryOp new_operator);
/// @copybrief Mutation::IsApplicable
///
/// The mutation is applicable iff:
/// - `binary_expr_id` is a valid id of an `ast::BinaryExpression`.
/// - `new_operator` is type-compatible with the arguments of the binary
/// expression.
///
/// @copydetails Mutation::IsApplicable
bool IsApplicable(const tint::Program& program,
const NodeIdMap& node_id_map) const override;
/// @copybrief Mutation::Apply
///
/// Replaces binary operator in the binary expression corresponding to
/// `binary_expr_id` with `new_operator`.
///
/// @copydetails Mutation::Apply
void Apply(const NodeIdMap& node_id_map,
tint::CloneContext* clone_context,
NodeIdMap* new_node_id_map) const override;
protobufs::Mutation ToMessage() const override;
/// @brief Determines whether replacing the operator of a binary expression
/// with another operator would preserve well-typedness.
/// @param program - the program that owns the binary expression.
/// @param binary_expr - the binary expression being considered for mutation.
/// @param new_operator - a new binary operator to be checked as a candidate
/// replacement for the binary expression's operator.
/// @return `true` if and only if the replacement would be well-typed.
static bool CanReplaceBinaryOperator(const Program& program,
const ast::BinaryExpression& binary_expr,
ast::BinaryOp new_operator);
private:
protobufs::MutationChangeBinaryOperator message_;
};
} // namespace ast_fuzzer
} // namespace fuzzers
} // namespace tint
#endif // SRC_TINT_FUZZERS_TINT_AST_FUZZER_MUTATIONS_CHANGE_BINARY_OPERATOR_H_

View File

@ -0,0 +1,749 @@
// Copyright 2022 The Tint Authors.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include "src/tint/fuzzers/tint_ast_fuzzer/mutations/change_binary_operator.h"
#include <string>
#include <unordered_set>
#include <vector>
#include "gtest/gtest.h"
#include "src/tint/fuzzers/tint_ast_fuzzer/mutator.h"
#include "src/tint/fuzzers/tint_ast_fuzzer/node_id_map.h"
#include "src/tint/program_builder.h"
#include "src/tint/reader/wgsl/parser.h"
#include "src/tint/writer/wgsl/generator.h"
namespace tint {
namespace fuzzers {
namespace ast_fuzzer {
namespace {
std::string OpToString(ast::BinaryOp op) {
switch (op) {
case ast::BinaryOp::kNone:
assert(false && "Unreachable");
return "";
case ast::BinaryOp::kAnd:
return "&";
case ast::BinaryOp::kOr:
return "|";
case ast::BinaryOp::kXor:
return "^";
case ast::BinaryOp::kLogicalAnd:
return "&&";
case ast::BinaryOp::kLogicalOr:
return "||";
case ast::BinaryOp::kEqual:
return "==";
case ast::BinaryOp::kNotEqual:
return "!=";
case ast::BinaryOp::kLessThan:
return "<";
case ast::BinaryOp::kGreaterThan:
return ">";
case ast::BinaryOp::kLessThanEqual:
return "<=";
case ast::BinaryOp::kGreaterThanEqual:
return ">=";
case ast::BinaryOp::kShiftLeft:
return "<<";
case ast::BinaryOp::kShiftRight:
return ">>";
case ast::BinaryOp::kAdd:
return "+";
case ast::BinaryOp::kSubtract:
return "-";
case ast::BinaryOp::kMultiply:
return "*";
case ast::BinaryOp::kDivide:
return "/";
case ast::BinaryOp::kModulo:
return "%";
}
}
TEST(ChangeBinaryOperatorTest, NotApplicable_Simple) {
std::string content = R"(
fn main() {
let a : i32 = 1 + 2;
}
)";
Source::File file("test.wgsl", content);
auto program = reader::wgsl::Parse(&file);
ASSERT_TRUE(program.IsValid()) << program.Diagnostics().str();
NodeIdMap node_id_map(program);
const auto& main_fn_stmts = program.AST().Functions()[0]->body->statements;
const auto* a_var =
main_fn_stmts[0]->As<ast::VariableDeclStatement>()->variable;
ASSERT_NE(a_var, nullptr);
auto a_var_id = node_id_map.GetId(a_var);
const auto* sum_expr = a_var->constructor->As<ast::BinaryExpression>();
ASSERT_NE(sum_expr, nullptr);
auto sum_expr_id = node_id_map.GetId(sum_expr);
ASSERT_NE(sum_expr_id, 0);
// binary_expr_id is invalid.
EXPECT_FALSE(MutationChangeBinaryOperator(0, ast::BinaryOp::kSubtract)
.IsApplicable(program, node_id_map));
// binary_expr_id is not a binary expression.
EXPECT_FALSE(MutationChangeBinaryOperator(a_var_id, ast::BinaryOp::kSubtract)
.IsApplicable(program, node_id_map));
// new_operator is applicable to the argument types.
EXPECT_FALSE(MutationChangeBinaryOperator(0, ast::BinaryOp::kLogicalAnd)
.IsApplicable(program, node_id_map));
// new_operator does not have the right result type.
EXPECT_FALSE(MutationChangeBinaryOperator(0, ast::BinaryOp::kLessThan)
.IsApplicable(program, node_id_map));
}
TEST(ChangeBinaryOperatorTest, Applicable_Simple) {
std::string shader = R"(fn main() {
let a : i32 = (1 + 2);
}
)";
Source::File file("test.wgsl", shader);
auto program = reader::wgsl::Parse(&file);
ASSERT_TRUE(program.IsValid()) << program.Diagnostics().str();
NodeIdMap node_id_map(program);
const auto& main_fn_stmts = program.AST().Functions()[0]->body->statements;
const auto* a_var =
main_fn_stmts[0]->As<ast::VariableDeclStatement>()->variable;
ASSERT_NE(a_var, nullptr);
const auto* sum_expr = a_var->constructor->As<ast::BinaryExpression>();
ASSERT_NE(sum_expr, nullptr);
auto sum_expr_id = node_id_map.GetId(sum_expr);
ASSERT_NE(sum_expr_id, 0);
ASSERT_TRUE(MaybeApplyMutation(
program,
MutationChangeBinaryOperator(sum_expr_id, ast::BinaryOp::kSubtract),
node_id_map, &program, &node_id_map, nullptr));
ASSERT_TRUE(program.IsValid()) << program.Diagnostics().str();
writer::wgsl::Options options;
auto result = writer::wgsl::Generate(&program, options);
ASSERT_TRUE(result.success) << result.error;
std::string expected_shader = R"(fn main() {
let a : i32 = (1 - 2);
}
)";
ASSERT_EQ(expected_shader, result.wgsl);
}
void CheckMutations(
const std::string& lhs_type,
const std::string& rhs_type,
const std::string& result_type,
ast::BinaryOp original_operator,
const std::unordered_set<ast::BinaryOp>& allowed_replacement_operators) {
std::stringstream shader;
shader << "fn foo(a : " << lhs_type << ", b : " << rhs_type + ") {\n"
<< " let r : " << result_type
<< " = (a " + OpToString(original_operator) << " b);\n}\n";
const std::vector<ast::BinaryOp> all_operators = {
ast::BinaryOp::kAnd,
ast::BinaryOp::kOr,
ast::BinaryOp::kXor,
ast::BinaryOp::kLogicalAnd,
ast::BinaryOp::kLogicalOr,
ast::BinaryOp::kEqual,
ast::BinaryOp::kNotEqual,
ast::BinaryOp::kLessThan,
ast::BinaryOp::kGreaterThan,
ast::BinaryOp::kLessThanEqual,
ast::BinaryOp::kGreaterThanEqual,
ast::BinaryOp::kShiftLeft,
ast::BinaryOp::kShiftRight,
ast::BinaryOp::kAdd,
ast::BinaryOp::kSubtract,
ast::BinaryOp::kMultiply,
ast::BinaryOp::kDivide,
ast::BinaryOp::kModulo};
for (auto new_operator : all_operators) {
Source::File file("test.wgsl", shader.str());
auto program = reader::wgsl::Parse(&file);
ASSERT_TRUE(program.IsValid()) << program.Diagnostics().str();
NodeIdMap node_id_map(program);
const auto& stmts = program.AST().Functions()[0]->body->statements;
const auto* r_var = stmts[0]->As<ast::VariableDeclStatement>()->variable;
ASSERT_NE(r_var, nullptr);
const auto* binary_expr = r_var->constructor->As<ast::BinaryExpression>();
ASSERT_NE(binary_expr, nullptr);
auto binary_expr_id = node_id_map.GetId(binary_expr);
ASSERT_NE(binary_expr_id, 0);
MutationChangeBinaryOperator mutation(binary_expr_id, new_operator);
std::stringstream expected_shader;
expected_shader << "fn foo(a : " << lhs_type << ", b : " << rhs_type
<< ") {\n"
<< " let r : " << result_type << " = (a "
<< OpToString(new_operator) << " b);\n}\n";
if (allowed_replacement_operators.count(new_operator) == 0) {
ASSERT_FALSE(mutation.IsApplicable(program, node_id_map));
if (new_operator != binary_expr->op) {
Source::File invalid_file("test.wgsl", expected_shader.str());
auto invalid_program = reader::wgsl::Parse(&invalid_file);
ASSERT_FALSE(invalid_program.IsValid()) << program.Diagnostics().str();
}
} else {
ASSERT_TRUE(MaybeApplyMutation(program, mutation, node_id_map, &program,
&node_id_map, nullptr));
ASSERT_TRUE(program.IsValid()) << program.Diagnostics().str();
writer::wgsl::Options options;
auto result = writer::wgsl::Generate(&program, options);
ASSERT_TRUE(result.success) << result.error;
ASSERT_EQ(expected_shader.str(), result.wgsl);
}
}
}
TEST(ChangeBinaryOperatorTest, AddSubtract) {
for (auto op : {ast::BinaryOp::kAdd, ast::BinaryOp::kSubtract}) {
const ast::BinaryOp other_op = op == ast::BinaryOp::kAdd
? ast::BinaryOp::kSubtract
: ast::BinaryOp::kAdd;
for (std::string type : {"i32", "vec2<i32>", "vec3<i32>", "vec4<i32>"}) {
CheckMutations(
type, type, type, op,
{other_op, ast::BinaryOp::kMultiply, ast::BinaryOp::kDivide,
ast::BinaryOp::kModulo, ast::BinaryOp::kAnd, ast::BinaryOp::kOr,
ast::BinaryOp::kXor});
}
for (std::string type : {"u32", "vec2<u32>", "vec3<u32>", "vec4<u32>"}) {
CheckMutations(
type, type, type, op,
{other_op, ast::BinaryOp::kMultiply, ast::BinaryOp::kDivide,
ast::BinaryOp::kModulo, ast::BinaryOp::kAnd, ast::BinaryOp::kOr,
ast::BinaryOp::kXor, ast::BinaryOp::kShiftLeft,
ast::BinaryOp::kShiftRight});
}
for (std::string type : {"f32", "vec2<f32>", "vec3<f32>", "vec4<f32>"}) {
CheckMutations(type, type, type, op,
{other_op, ast::BinaryOp::kMultiply,
ast::BinaryOp::kDivide, ast::BinaryOp::kModulo});
}
for (std::string vector_type : {"vec2<i32>", "vec3<i32>", "vec4<i32>"}) {
std::string scalar_type = "i32";
CheckMutations(vector_type, scalar_type, vector_type, op,
{other_op, ast::BinaryOp::kMultiply,
ast::BinaryOp::kDivide, ast::BinaryOp::kModulo});
CheckMutations(scalar_type, vector_type, vector_type, op,
{other_op, ast::BinaryOp::kMultiply,
ast::BinaryOp::kDivide, ast::BinaryOp::kModulo});
}
for (std::string vector_type : {"vec2<u32>", "vec3<u32>", "vec4<u32>"}) {
std::string scalar_type = "u32";
CheckMutations(vector_type, scalar_type, vector_type, op,
{other_op, ast::BinaryOp::kMultiply,
ast::BinaryOp::kDivide, ast::BinaryOp::kModulo});
CheckMutations(scalar_type, vector_type, vector_type, op,
{other_op, ast::BinaryOp::kMultiply,
ast::BinaryOp::kDivide, ast::BinaryOp::kModulo});
}
for (std::string vector_type : {"vec2<f32>", "vec3<f32>", "vec4<f32>"}) {
std::string scalar_type = "f32";
CheckMutations(
vector_type, scalar_type, vector_type, op,
{
other_op, ast::BinaryOp::kMultiply, ast::BinaryOp::kDivide
// TODO(https://crbug.com/tint/1370): once fixed, add kModulo
});
CheckMutations(
scalar_type, vector_type, vector_type, op,
{
other_op, ast::BinaryOp::kMultiply, ast::BinaryOp::kDivide
// TODO(https://crbug.com/tint/1370): once fixed, add kModulo
});
}
for (std::string square_matrix_type :
{"mat2x2<f32>", "mat3x3<f32>", "mat4x4<f32>"}) {
CheckMutations(square_matrix_type, square_matrix_type, square_matrix_type,
op, {other_op, ast::BinaryOp::kMultiply});
}
for (std::string non_square_matrix_type :
{"mat2x3<f32>", "mat2x4<f32>", "mat3x2<f32>", "mat3x4<f32>",
"mat4x2<f32>", "mat4x3<f32>"}) {
CheckMutations(non_square_matrix_type, non_square_matrix_type,
non_square_matrix_type, op, {other_op});
}
}
}
TEST(ChangeBinaryOperatorTest, Mul) {
for (std::string type : {"i32", "vec2<i32>", "vec3<i32>", "vec4<i32>"}) {
CheckMutations(
type, type, type, ast::BinaryOp::kMultiply,
{ast::BinaryOp::kAdd, ast::BinaryOp::kSubtract, ast::BinaryOp::kDivide,
ast::BinaryOp::kModulo, ast::BinaryOp::kAnd, ast::BinaryOp::kOr,
ast::BinaryOp::kXor});
}
for (std::string type : {"u32", "vec2<u32>", "vec3<u32>", "vec4<u32>"}) {
CheckMutations(
type, type, type, ast::BinaryOp::kMultiply,
{ast::BinaryOp::kAdd, ast::BinaryOp::kSubtract, ast::BinaryOp::kDivide,
ast::BinaryOp::kModulo, ast::BinaryOp::kAnd, ast::BinaryOp::kOr,
ast::BinaryOp::kXor, ast::BinaryOp::kShiftLeft,
ast::BinaryOp::kShiftRight});
}
for (std::string type : {"f32", "vec2<f32>", "vec3<f32>", "vec4<f32>"}) {
CheckMutations(type, type, type, ast::BinaryOp::kMultiply,
{ast::BinaryOp::kAdd, ast::BinaryOp::kSubtract,
ast::BinaryOp::kDivide, ast::BinaryOp::kModulo});
}
for (std::string vector_type : {"vec2<i32>", "vec3<i32>", "vec4<i32>"}) {
std::string scalar_type = "i32";
CheckMutations(vector_type, scalar_type, vector_type,
ast::BinaryOp::kMultiply,
{ast::BinaryOp::kAdd, ast::BinaryOp::kSubtract,
ast::BinaryOp::kDivide, ast::BinaryOp::kModulo});
CheckMutations(scalar_type, vector_type, vector_type,
ast::BinaryOp::kMultiply,
{ast::BinaryOp::kAdd, ast::BinaryOp::kSubtract,
ast::BinaryOp::kDivide, ast::BinaryOp::kModulo});
}
for (std::string vector_type : {"vec2<u32>", "vec3<u32>", "vec4<u32>"}) {
std::string scalar_type = "u32";
CheckMutations(vector_type, scalar_type, vector_type,
ast::BinaryOp::kMultiply,
{ast::BinaryOp::kAdd, ast::BinaryOp::kSubtract,
ast::BinaryOp::kDivide, ast::BinaryOp::kModulo});
CheckMutations(scalar_type, vector_type, vector_type,
ast::BinaryOp::kMultiply,
{ast::BinaryOp::kAdd, ast::BinaryOp::kSubtract,
ast::BinaryOp::kDivide, ast::BinaryOp::kModulo});
}
for (std::string vector_type : {"vec2<f32>", "vec3<f32>", "vec4<f32>"}) {
std::string scalar_type = "f32";
CheckMutations(
vector_type, scalar_type, vector_type, ast::BinaryOp::kMultiply,
{
ast::BinaryOp::kAdd, ast::BinaryOp::kSubtract,
ast::BinaryOp::kDivide
// TODO(https://crbug.com/tint/1370): once fixed, add kModulo
});
CheckMutations(
scalar_type, vector_type, vector_type, ast::BinaryOp::kMultiply,
{
ast::BinaryOp::kAdd, ast::BinaryOp::kSubtract,
ast::BinaryOp::kDivide
// TODO(https://crbug.com/tint/1370): once fixed, add kModulo
});
}
for (std::string square_matrix_type :
{"mat2x2<f32>", "mat3x3<f32>", "mat4x4<f32>"}) {
CheckMutations(square_matrix_type, square_matrix_type, square_matrix_type,
ast::BinaryOp::kMultiply,
{ast::BinaryOp::kAdd, ast::BinaryOp::kSubtract});
}
CheckMutations("vec2<f32>", "mat2x2<f32>", "vec2<f32>",
ast::BinaryOp::kMultiply, {});
CheckMutations("vec2<f32>", "mat3x2<f32>", "vec3<f32>",
ast::BinaryOp::kMultiply, {});
CheckMutations("vec2<f32>", "mat4x2<f32>", "vec4<f32>",
ast::BinaryOp::kMultiply, {});
CheckMutations("mat2x2<f32>", "vec2<f32>", "vec2<f32>",
ast::BinaryOp::kMultiply, {});
CheckMutations("mat2x2<f32>", "mat3x2<f32>", "mat3x2<f32>",
ast::BinaryOp::kMultiply, {});
CheckMutations("mat2x2<f32>", "mat4x2<f32>", "mat4x2<f32>",
ast::BinaryOp::kMultiply, {});
CheckMutations("mat2x3<f32>", "vec2<f32>", "vec3<f32>",
ast::BinaryOp::kMultiply, {});
CheckMutations("mat2x3<f32>", "mat2x2<f32>", "mat2x3<f32>",
ast::BinaryOp::kMultiply, {});
CheckMutations("mat2x3<f32>", "mat3x2<f32>", "mat3x3<f32>",
ast::BinaryOp::kMultiply, {});
CheckMutations("mat2x3<f32>", "mat4x2<f32>", "mat4x3<f32>",
ast::BinaryOp::kMultiply, {});
CheckMutations("mat2x4<f32>", "vec2<f32>", "vec4<f32>",
ast::BinaryOp::kMultiply, {});
CheckMutations("mat2x4<f32>", "mat2x2<f32>", "mat2x4<f32>",
ast::BinaryOp::kMultiply, {});
CheckMutations("mat2x4<f32>", "mat3x2<f32>", "mat3x4<f32>",
ast::BinaryOp::kMultiply, {});
CheckMutations("mat2x4<f32>", "mat4x2<f32>", "mat4x4<f32>",
ast::BinaryOp::kMultiply, {});
CheckMutations("vec3<f32>", "mat2x3<f32>", "vec2<f32>",
ast::BinaryOp::kMultiply, {});
CheckMutations("vec3<f32>", "mat3x3<f32>", "vec3<f32>",
ast::BinaryOp::kMultiply, {});
CheckMutations("vec3<f32>", "mat4x3<f32>", "vec4<f32>",
ast::BinaryOp::kMultiply, {});
CheckMutations("mat3x2<f32>", "vec3<f32>", "vec2<f32>",
ast::BinaryOp::kMultiply, {});
CheckMutations("mat3x2<f32>", "mat2x3<f32>", "mat2x2<f32>",
ast::BinaryOp::kMultiply, {});
CheckMutations("mat3x2<f32>", "mat3x3<f32>", "mat3x2<f32>",
ast::BinaryOp::kMultiply, {});
CheckMutations("mat3x2<f32>", "mat4x3<f32>", "mat4x2<f32>",
ast::BinaryOp::kMultiply, {});
CheckMutations("mat3x3<f32>", "vec3<f32>", "vec3<f32>",
ast::BinaryOp::kMultiply, {});
CheckMutations("mat3x3<f32>", "mat2x3<f32>", "mat2x3<f32>",
ast::BinaryOp::kMultiply, {});
CheckMutations("mat3x3<f32>", "mat4x3<f32>", "mat4x3<f32>",
ast::BinaryOp::kMultiply, {});
CheckMutations("mat3x4<f32>", "vec3<f32>", "vec4<f32>",
ast::BinaryOp::kMultiply, {});
CheckMutations("mat3x4<f32>", "mat2x3<f32>", "mat2x4<f32>",
ast::BinaryOp::kMultiply, {});
CheckMutations("mat3x4<f32>", "mat3x3<f32>", "mat3x4<f32>",
ast::BinaryOp::kMultiply, {});
CheckMutations("mat3x4<f32>", "mat4x3<f32>", "mat4x4<f32>",
ast::BinaryOp::kMultiply, {});
CheckMutations("vec4<f32>", "mat2x4<f32>", "vec2<f32>",
ast::BinaryOp::kMultiply, {});
CheckMutations("vec4<f32>", "mat3x4<f32>", "vec3<f32>",
ast::BinaryOp::kMultiply, {});
CheckMutations("vec4<f32>", "mat4x4<f32>", "vec4<f32>",
ast::BinaryOp::kMultiply, {});
CheckMutations("mat4x2<f32>", "vec4<f32>", "vec2<f32>",
ast::BinaryOp::kMultiply, {});
CheckMutations("mat4x2<f32>", "mat2x4<f32>", "mat2x2<f32>",
ast::BinaryOp::kMultiply, {});
CheckMutations("mat4x2<f32>", "mat3x4<f32>", "mat3x2<f32>",
ast::BinaryOp::kMultiply, {});
CheckMutations("mat4x2<f32>", "mat4x4<f32>", "mat4x2<f32>",
ast::BinaryOp::kMultiply, {});
CheckMutations("mat4x3<f32>", "vec4<f32>", "vec3<f32>",
ast::BinaryOp::kMultiply, {});
CheckMutations("mat4x3<f32>", "mat2x4<f32>", "mat2x3<f32>",
ast::BinaryOp::kMultiply, {});
CheckMutations("mat4x3<f32>", "mat3x4<f32>", "mat3x3<f32>",
ast::BinaryOp::kMultiply, {});
CheckMutations("mat4x3<f32>", "mat4x4<f32>", "mat4x3<f32>",
ast::BinaryOp::kMultiply, {});
CheckMutations("mat4x4<f32>", "vec4<f32>", "vec4<f32>",
ast::BinaryOp::kMultiply, {});
CheckMutations("mat4x4<f32>", "mat2x4<f32>", "mat2x4<f32>",
ast::BinaryOp::kMultiply, {});
CheckMutations("mat4x4<f32>", "mat3x4<f32>", "mat3x4<f32>",
ast::BinaryOp::kMultiply, {});
}
TEST(ChangeBinaryOperatorTest, Divide) {
for (std::string type : {"i32", "vec2<i32>", "vec3<i32>", "vec4<i32>"}) {
CheckMutations(
type, type, type, ast::BinaryOp::kDivide,
{ast::BinaryOp::kAdd, ast::BinaryOp::kSubtract,
ast::BinaryOp::kMultiply, ast::BinaryOp::kModulo, ast::BinaryOp::kAnd,
ast::BinaryOp::kOr, ast::BinaryOp::kXor});
}
for (std::string type : {"u32", "vec2<u32>", "vec3<u32>", "vec4<u32>"}) {
CheckMutations(
type, type, type, ast::BinaryOp::kDivide,
{ast::BinaryOp::kAdd, ast::BinaryOp::kSubtract,
ast::BinaryOp::kMultiply, ast::BinaryOp::kModulo, ast::BinaryOp::kAnd,
ast::BinaryOp::kOr, ast::BinaryOp::kXor, ast::BinaryOp::kShiftLeft,
ast::BinaryOp::kShiftRight});
}
for (std::string type : {"f32", "vec2<f32>", "vec3<f32>", "vec4<f32>"}) {
CheckMutations(type, type, type, ast::BinaryOp::kDivide,
{ast::BinaryOp::kAdd, ast::BinaryOp::kSubtract,
ast::BinaryOp::kMultiply, ast::BinaryOp::kModulo});
}
for (std::string vector_type : {"vec2<i32>", "vec3<i32>", "vec4<i32>"}) {
std::string scalar_type = "i32";
CheckMutations(vector_type, scalar_type, vector_type,
ast::BinaryOp::kDivide,
{ast::BinaryOp::kAdd, ast::BinaryOp::kSubtract,
ast::BinaryOp::kMultiply, ast::BinaryOp::kModulo});
CheckMutations(scalar_type, vector_type, vector_type,
ast::BinaryOp::kDivide,
{ast::BinaryOp::kAdd, ast::BinaryOp::kSubtract,
ast::BinaryOp::kMultiply, ast::BinaryOp::kModulo});
}
for (std::string vector_type : {"vec2<u32>", "vec3<u32>", "vec4<u32>"}) {
std::string scalar_type = "u32";
CheckMutations(vector_type, scalar_type, vector_type,
ast::BinaryOp::kDivide,
{ast::BinaryOp::kAdd, ast::BinaryOp::kSubtract,
ast::BinaryOp::kMultiply, ast::BinaryOp::kModulo});
CheckMutations(scalar_type, vector_type, vector_type,
ast::BinaryOp::kDivide,
{ast::BinaryOp::kAdd, ast::BinaryOp::kSubtract,
ast::BinaryOp::kMultiply, ast::BinaryOp::kModulo});
}
for (std::string vector_type : {"vec2<f32>", "vec3<f32>", "vec4<f32>"}) {
std::string scalar_type = "f32";
CheckMutations(
vector_type, scalar_type, vector_type, ast::BinaryOp::kDivide,
{
ast::BinaryOp::kAdd, ast::BinaryOp::kSubtract,
ast::BinaryOp::kMultiply
// TODO(https://crbug.com/tint/1370): once fixed, add kModulo
});
CheckMutations(
scalar_type, vector_type, vector_type, ast::BinaryOp::kDivide,
{
ast::BinaryOp::kAdd, ast::BinaryOp::kSubtract,
ast::BinaryOp::kMultiply
// TODO(https://crbug.com/tint/1370): once fixed, add kModulo
});
}
}
// TODO(https://crbug.com/tint/1370): once fixed, combine this with the Divide
// test
TEST(ChangeBinaryOperatorTest, Modulo) {
for (std::string type : {"i32", "vec2<i32>", "vec3<i32>", "vec4<i32>"}) {
CheckMutations(
type, type, type, ast::BinaryOp::kModulo,
{ast::BinaryOp::kAdd, ast::BinaryOp::kSubtract,
ast::BinaryOp::kMultiply, ast::BinaryOp::kDivide, ast::BinaryOp::kAnd,
ast::BinaryOp::kOr, ast::BinaryOp::kXor});
}
for (std::string type : {"u32", "vec2<u32>", "vec3<u32>", "vec4<u32>"}) {
CheckMutations(
type, type, type, ast::BinaryOp::kModulo,
{ast::BinaryOp::kAdd, ast::BinaryOp::kSubtract,
ast::BinaryOp::kMultiply, ast::BinaryOp::kDivide, ast::BinaryOp::kAnd,
ast::BinaryOp::kOr, ast::BinaryOp::kXor, ast::BinaryOp::kShiftLeft,
ast::BinaryOp::kShiftRight});
}
for (std::string type : {"f32", "vec2<f32>", "vec3<f32>", "vec4<f32>"}) {
CheckMutations(type, type, type, ast::BinaryOp::kModulo,
{ast::BinaryOp::kAdd, ast::BinaryOp::kSubtract,
ast::BinaryOp::kMultiply, ast::BinaryOp::kDivide});
}
for (std::string vector_type : {"vec2<i32>", "vec3<i32>", "vec4<i32>"}) {
std::string scalar_type = "i32";
CheckMutations(vector_type, scalar_type, vector_type,
ast::BinaryOp::kModulo,
{ast::BinaryOp::kAdd, ast::BinaryOp::kSubtract,
ast::BinaryOp::kMultiply, ast::BinaryOp::kDivide});
CheckMutations(scalar_type, vector_type, vector_type,
ast::BinaryOp::kModulo,
{ast::BinaryOp::kAdd, ast::BinaryOp::kSubtract,
ast::BinaryOp::kMultiply, ast::BinaryOp::kDivide});
}
for (std::string vector_type : {"vec2<u32>", "vec3<u32>", "vec4<u32>"}) {
std::string scalar_type = "u32";
CheckMutations(vector_type, scalar_type, vector_type,
ast::BinaryOp::kModulo,
{ast::BinaryOp::kAdd, ast::BinaryOp::kSubtract,
ast::BinaryOp::kMultiply, ast::BinaryOp::kDivide});
CheckMutations(scalar_type, vector_type, vector_type,
ast::BinaryOp::kModulo,
{ast::BinaryOp::kAdd, ast::BinaryOp::kSubtract,
ast::BinaryOp::kMultiply, ast::BinaryOp::kDivide});
}
// TODO(https://crbug.com/tint/1370): mixed float scalars/vectors will be
// added when this test is combined with the Divide test
}
TEST(ChangeBinaryOperatorTest, AndOrXor) {
for (auto op :
{ast::BinaryOp::kAnd, ast::BinaryOp::kOr, ast::BinaryOp::kXor}) {
std::unordered_set<ast::BinaryOp> allowed_replacement_operators_signed{
ast::BinaryOp::kAdd, ast::BinaryOp::kSubtract,
ast::BinaryOp::kMultiply, ast::BinaryOp::kDivide,
ast::BinaryOp::kModulo, ast::BinaryOp::kAnd,
ast::BinaryOp::kOr, ast::BinaryOp::kXor};
allowed_replacement_operators_signed.erase(op);
for (std::string type : {"i32", "vec2<i32>", "vec3<i32>", "vec4<i32>"}) {
CheckMutations(type, type, type, op,
allowed_replacement_operators_signed);
}
std::unordered_set<ast::BinaryOp> allowed_replacement_operators_unsigned{
ast::BinaryOp::kAdd, ast::BinaryOp::kSubtract,
ast::BinaryOp::kMultiply, ast::BinaryOp::kDivide,
ast::BinaryOp::kModulo, ast::BinaryOp::kShiftLeft,
ast::BinaryOp::kShiftRight, ast::BinaryOp::kAnd,
ast::BinaryOp::kOr, ast::BinaryOp::kXor};
allowed_replacement_operators_unsigned.erase(op);
for (std::string type : {"u32", "vec2<u32>", "vec3<u32>", "vec4<u32>"}) {
CheckMutations(type, type, type, op,
allowed_replacement_operators_unsigned);
}
if (op != ast::BinaryOp::kXor) {
for (std::string type :
{"bool", "vec2<bool>", "vec3<bool>", "vec4<bool>"}) {
std::unordered_set<ast::BinaryOp> allowed_replacement_operators_bool{
ast::BinaryOp::kAnd, ast::BinaryOp::kOr, ast::BinaryOp::kEqual,
ast::BinaryOp::kNotEqual};
allowed_replacement_operators_bool.erase(op);
if (type == "bool") {
allowed_replacement_operators_bool.insert(ast::BinaryOp::kLogicalAnd);
allowed_replacement_operators_bool.insert(ast::BinaryOp::kLogicalOr);
}
CheckMutations(type, type, type, op,
allowed_replacement_operators_bool);
}
}
}
}
TEST(ChangeBinaryOperatorTest, EqualNotEqual) {
for (auto op : {ast::BinaryOp::kEqual, ast::BinaryOp::kNotEqual}) {
for (std::string element_type : {"i32", "u32", "f32"}) {
for (size_t element_count = 1; element_count <= 4; element_count++) {
std::stringstream argument_type;
std::stringstream result_type;
if (element_count == 1) {
argument_type << element_type;
result_type << "bool";
} else {
argument_type << "vec" << element_count << "<" << element_type << ">";
result_type << "vec" << element_count << "<bool>";
}
std::unordered_set<ast::BinaryOp> allowed_replacement_operators{
ast::BinaryOp::kLessThan, ast::BinaryOp::kLessThanEqual,
ast::BinaryOp::kGreaterThan, ast::BinaryOp::kGreaterThanEqual,
ast::BinaryOp::kEqual, ast::BinaryOp::kNotEqual};
allowed_replacement_operators.erase(op);
CheckMutations(argument_type.str(), argument_type.str(),
result_type.str(), op, allowed_replacement_operators);
}
}
{
std::unordered_set<ast::BinaryOp> allowed_replacement_operators{
ast::BinaryOp::kLogicalAnd, ast::BinaryOp::kLogicalOr,
ast::BinaryOp::kAnd, ast::BinaryOp::kOr,
ast::BinaryOp::kEqual, ast::BinaryOp::kNotEqual};
allowed_replacement_operators.erase(op);
CheckMutations("bool", "bool", "bool", op, allowed_replacement_operators);
}
for (size_t element_count = 2; element_count <= 4; element_count++) {
std::stringstream argument_and_result_type;
argument_and_result_type << "vec" << element_count << "<bool>";
std::unordered_set<ast::BinaryOp> allowed_replacement_operators{
ast::BinaryOp::kAnd, ast::BinaryOp::kOr, ast::BinaryOp::kEqual,
ast::BinaryOp::kNotEqual};
allowed_replacement_operators.erase(op);
CheckMutations(
argument_and_result_type.str(), argument_and_result_type.str(),
argument_and_result_type.str(), op, allowed_replacement_operators);
}
}
}
TEST(ChangeBinaryOperatorTest,
LessThanLessThanEqualGreaterThanGreaterThanEqual) {
for (auto op :
{ast::BinaryOp::kLessThan, ast::BinaryOp::kLessThanEqual,
ast::BinaryOp::kGreaterThan, ast::BinaryOp::kGreaterThanEqual}) {
for (std::string element_type : {"i32", "u32", "f32"}) {
for (size_t element_count = 1; element_count <= 4; element_count++) {
std::stringstream argument_type;
std::stringstream result_type;
if (element_count == 1) {
argument_type << element_type;
result_type << "bool";
} else {
argument_type << "vec" << element_count << "<" << element_type << ">";
result_type << "vec" << element_count << "<bool>";
}
std::unordered_set<ast::BinaryOp> allowed_replacement_operators{
ast::BinaryOp::kLessThan, ast::BinaryOp::kLessThanEqual,
ast::BinaryOp::kGreaterThan, ast::BinaryOp::kGreaterThanEqual,
ast::BinaryOp::kEqual, ast::BinaryOp::kNotEqual};
allowed_replacement_operators.erase(op);
CheckMutations(argument_type.str(), argument_type.str(),
result_type.str(), op, allowed_replacement_operators);
}
}
}
}
TEST(ChangeBinaryOperatorTest, LogicalAndLogicalOr) {
for (auto op : {ast::BinaryOp::kLogicalAnd, ast::BinaryOp::kLogicalOr}) {
std::unordered_set<ast::BinaryOp> allowed_replacement_operators{
ast::BinaryOp::kLogicalAnd, ast::BinaryOp::kLogicalOr,
ast::BinaryOp::kAnd, ast::BinaryOp::kOr,
ast::BinaryOp::kEqual, ast::BinaryOp::kNotEqual};
allowed_replacement_operators.erase(op);
CheckMutations("bool", "bool", "bool", op, allowed_replacement_operators);
}
}
TEST(ChangeBinaryOperatorTest, ShiftLeftShiftRight) {
for (auto op : {ast::BinaryOp::kShiftLeft, ast::BinaryOp::kShiftRight}) {
for (std::string lhs_element_type : {"i32", "u32"}) {
for (size_t element_count = 1; element_count <= 4; element_count++) {
std::stringstream lhs_and_result_type;
std::stringstream rhs_type;
if (element_count == 1) {
lhs_and_result_type << lhs_element_type;
rhs_type << "u32";
} else {
lhs_and_result_type << "vec" << element_count << "<"
<< lhs_element_type << ">";
rhs_type << "vec" << element_count << "<u32>";
}
std::unordered_set<ast::BinaryOp> allowed_replacement_operators{
ast::BinaryOp::kShiftLeft, ast::BinaryOp::kShiftRight};
allowed_replacement_operators.erase(op);
if (lhs_element_type == "u32") {
allowed_replacement_operators.insert(ast::BinaryOp::kAdd);
allowed_replacement_operators.insert(ast::BinaryOp::kSubtract);
allowed_replacement_operators.insert(ast::BinaryOp::kMultiply);
allowed_replacement_operators.insert(ast::BinaryOp::kDivide);
allowed_replacement_operators.insert(ast::BinaryOp::kModulo);
allowed_replacement_operators.insert(ast::BinaryOp::kAnd);
allowed_replacement_operators.insert(ast::BinaryOp::kOr);
allowed_replacement_operators.insert(ast::BinaryOp::kXor);
}
CheckMutations(lhs_and_result_type.str(), rhs_type.str(),
lhs_and_result_type.str(), op,
allowed_replacement_operators);
}
}
}
}
} // namespace
} // namespace ast_fuzzer
} // namespace fuzzers
} // namespace tint

View File

@ -12,17 +12,15 @@
// See the License for the specific language governing permissions and // See the License for the specific language governing permissions and
// limitations under the License. // limitations under the License.
#include "src/tint/fuzzers/tint_ast_fuzzer/mutations/replace_identifier.h"
#include <string> #include <string>
#include "gtest/gtest.h" #include "gtest/gtest.h"
#include "src/tint/fuzzers/tint_ast_fuzzer/mutations/replace_identifier.h"
#include "src/tint/fuzzers/tint_ast_fuzzer/mutator.h"
#include "src/tint/fuzzers/tint_ast_fuzzer/probability_context.h"
#include "src/tint/fuzzers/tint_ast_fuzzer/node_id_map.h"
#include "src/tint/ast/call_statement.h" #include "src/tint/ast/call_statement.h"
#include "src/tint/fuzzers/tint_ast_fuzzer/mutator.h"
#include "src/tint/fuzzers/tint_ast_fuzzer/node_id_map.h"
#include "src/tint/program_builder.h" #include "src/tint/program_builder.h"
#include "src/tint/reader/wgsl/parser.h" #include "src/tint/reader/wgsl/parser.h"
#include "src/tint/writer/wgsl/generator.h" #include "src/tint/writer/wgsl/generator.h"

View File

@ -20,6 +20,7 @@
#include <utility> #include <utility>
#include <vector> #include <vector>
#include "src/tint/fuzzers/tint_ast_fuzzer/mutation_finders/change_binary_operators.h"
#include "src/tint/fuzzers/tint_ast_fuzzer/mutation_finders/replace_identifiers.h" #include "src/tint/fuzzers/tint_ast_fuzzer/mutation_finders/replace_identifiers.h"
#include "src/tint/fuzzers/tint_ast_fuzzer/node_id_map.h" #include "src/tint/fuzzers/tint_ast_fuzzer/node_id_map.h"
@ -45,6 +46,8 @@ MutationFinderList CreateMutationFinders(
bool enable_all_mutations) { bool enable_all_mutations) {
MutationFinderList result; MutationFinderList result;
do { do {
MaybeAddFinder<MutationFinderChangeBinaryOperators>(
enable_all_mutations, probability_context, &result);
MaybeAddFinder<MutationFinderReplaceIdentifiers>( MaybeAddFinder<MutationFinderReplaceIdentifiers>(
enable_all_mutations, probability_context, &result); enable_all_mutations, probability_context, &result);
} while (result.empty()); } while (result.empty());

View File

@ -21,12 +21,15 @@ namespace fuzzers {
namespace ast_fuzzer { namespace ast_fuzzer {
namespace { namespace {
const std::pair<uint32_t, uint32_t> kChanceOfChangingBinaryOperators = {30, 90};
const std::pair<uint32_t, uint32_t> kChanceOfReplacingIdentifiers = {30, 70}; const std::pair<uint32_t, uint32_t> kChanceOfReplacingIdentifiers = {30, 70};
} // namespace } // namespace
ProbabilityContext::ProbabilityContext(RandomGenerator* generator) ProbabilityContext::ProbabilityContext(RandomGenerator* generator)
: generator_(generator), : generator_(generator),
chance_of_changing_binary_operators_(
RandomFromRange(kChanceOfChangingBinaryOperators)),
chance_of_replacing_identifiers_( chance_of_replacing_identifiers_(
RandomFromRange(kChanceOfReplacingIdentifiers)) { RandomFromRange(kChanceOfReplacingIdentifiers)) {
assert(generator != nullptr && "generator must not be nullptr"); assert(generator != nullptr && "generator must not be nullptr");

View File

@ -55,6 +55,11 @@ class ProbabilityContext {
return static_cast<size_t>(generator_->GetUInt64(arr.size())); return static_cast<size_t>(generator_->GetUInt64(arr.size()));
} }
/// @return the probability of replacing some binary operator with another.
uint32_t GetChanceOfChangingBinaryOperators() const {
return chance_of_changing_binary_operators_;
}
/// @return the probability of replacing some identifier with some other one. /// @return the probability of replacing some identifier with some other one.
uint32_t GetChanceOfReplacingIdentifiers() const { uint32_t GetChanceOfReplacingIdentifiers() const {
return chance_of_replacing_identifiers_; return chance_of_replacing_identifiers_;
@ -67,6 +72,7 @@ class ProbabilityContext {
RandomGenerator* generator_; RandomGenerator* generator_;
uint32_t chance_of_changing_binary_operators_;
uint32_t chance_of_replacing_identifiers_; uint32_t chance_of_replacing_identifiers_;
}; };

View File

@ -17,7 +17,10 @@ syntax = "proto3";
package tint.fuzzers.ast_fuzzer.protobufs; package tint.fuzzers.ast_fuzzer.protobufs;
message Mutation { message Mutation {
oneof mutation { MutationReplaceIdentifier replace_identifier = 1; }; oneof mutation {
MutationReplaceIdentifier replace_identifier = 1;
MutationChangeBinaryOperator change_binary_operator = 2;
};
} }
message MutationSequence { message MutationSequence {
@ -46,3 +49,13 @@ message MutationReplaceIdentifier {
// The id of a definition of a variable to replace the use with. // The id of a definition of a variable to replace the use with.
uint32 replacement_id = 2; uint32 replacement_id = 2;
} }
message MutationChangeBinaryOperator {
// This transformation replaces one binary operator with another.
// The id of a binary expression in the AST.
uint32 binary_expr_id = 1;
// A BinaryOp representing the new binary operator.
uint32 new_operator = 2;
}

View File

@ -1723,7 +1723,7 @@ class ProgramBuilder {
/// @param rhs the right hand argument to the division operation /// @param rhs the right hand argument to the division operation
/// @returns a `ast::BinaryExpression` dividing `lhs` by `rhs` /// @returns a `ast::BinaryExpression` dividing `lhs` by `rhs`
template <typename LHS, typename RHS> template <typename LHS, typename RHS>
const ast::Expression* Div(LHS&& lhs, RHS&& rhs) { const ast::BinaryExpression* Div(LHS&& lhs, RHS&& rhs) {
return create<ast::BinaryExpression>(ast::BinaryOp::kDivide, return create<ast::BinaryExpression>(ast::BinaryOp::kDivide,
Expr(std::forward<LHS>(lhs)), Expr(std::forward<LHS>(lhs)),
Expr(std::forward<RHS>(rhs))); Expr(std::forward<RHS>(rhs)));
@ -1733,7 +1733,7 @@ class ProgramBuilder {
/// @param rhs the right hand argument to the modulo operation /// @param rhs the right hand argument to the modulo operation
/// @returns a `ast::BinaryExpression` applying modulo of `lhs` by `rhs` /// @returns a `ast::BinaryExpression` applying modulo of `lhs` by `rhs`
template <typename LHS, typename RHS> template <typename LHS, typename RHS>
const ast::Expression* Mod(LHS&& lhs, RHS&& rhs) { const ast::BinaryExpression* Mod(LHS&& lhs, RHS&& rhs) {
return create<ast::BinaryExpression>(ast::BinaryOp::kModulo, return create<ast::BinaryExpression>(ast::BinaryOp::kModulo,
Expr(std::forward<LHS>(lhs)), Expr(std::forward<LHS>(lhs)),
Expr(std::forward<RHS>(rhs))); Expr(std::forward<RHS>(rhs)));
@ -1769,6 +1769,26 @@ class ProgramBuilder {
Expr(std::forward<RHS>(rhs))); Expr(std::forward<RHS>(rhs)));
} }
/// @param lhs the left hand argument to the logical and operation
/// @param rhs the right hand argument to the logical and operation
/// @returns a `ast::BinaryExpression` of `lhs` && `rhs`
template <typename LHS, typename RHS>
const ast::BinaryExpression* LogicalAnd(LHS&& lhs, RHS&& rhs) {
return create<ast::BinaryExpression>(ast::BinaryOp::kLogicalAnd,
Expr(std::forward<LHS>(lhs)),
Expr(std::forward<RHS>(rhs)));
}
/// @param lhs the left hand argument to the logical or operation
/// @param rhs the right hand argument to the logical or operation
/// @returns a `ast::BinaryExpression` of `lhs` || `rhs`
template <typename LHS, typename RHS>
const ast::BinaryExpression* LogicalOr(LHS&& lhs, RHS&& rhs) {
return create<ast::BinaryExpression>(ast::BinaryOp::kLogicalOr,
Expr(std::forward<LHS>(lhs)),
Expr(std::forward<RHS>(rhs)));
}
/// @param lhs the left hand argument to the greater than operation /// @param lhs the left hand argument to the greater than operation
/// @param rhs the right hand argument to the greater than operation /// @param rhs the right hand argument to the greater than operation
/// @returns a `ast::BinaryExpression` of `lhs` > `rhs` /// @returns a `ast::BinaryExpression` of `lhs` > `rhs`
@ -1819,6 +1839,17 @@ class ProgramBuilder {
Expr(std::forward<RHS>(rhs))); Expr(std::forward<RHS>(rhs)));
} }
/// @param lhs the left hand argument to the not-equal expression
/// @param rhs the right hand argument to the not-equal expression
/// @returns a `ast::BinaryExpression` comparing `lhs` equal to `rhs` for
/// disequality
template <typename LHS, typename RHS>
const ast::BinaryExpression* NotEqual(LHS&& lhs, RHS&& rhs) {
return create<ast::BinaryExpression>(ast::BinaryOp::kNotEqual,
Expr(std::forward<LHS>(lhs)),
Expr(std::forward<RHS>(rhs)));
}
/// @param source the source information /// @param source the source information
/// @param obj the object for the index accessor expression /// @param obj the object for the index accessor expression
/// @param idx the index argument for the index accessor expression /// @param idx the index argument for the index accessor expression

View File

@ -80,6 +80,12 @@ bool Type::is_float_matrix() const {
return Is([](const Matrix* m) { return m->type()->is_float_scalar(); }); return Is([](const Matrix* m) { return m->type()->is_float_scalar(); });
} }
bool Type::is_square_float_matrix() const {
return Is([](const Matrix* m) {
return m->type()->is_float_scalar() && m->rows() == m->columns();
});
}
bool Type::is_float_vector() const { bool Type::is_float_vector() const {
return Is([](const Vector* v) { return v->type()->is_float_scalar(); }); return Is([](const Vector* v) { return v->type()->is_float_scalar(); });
} }

View File

@ -77,6 +77,8 @@ class Type : public Castable<Type, Node> {
bool is_float_scalar() const; bool is_float_scalar() const;
/// @returns true if this type is a float matrix /// @returns true if this type is a float matrix
bool is_float_matrix() const; bool is_float_matrix() const;
/// @returns true if this type is a square float matrix
bool is_square_float_matrix() const;
/// @returns true if this type is a float vector /// @returns true if this type is a float vector
bool is_float_vector() const; bool is_float_vector() const;
/// @returns true if this type is a float scalar or vector /// @returns true if this type is a float scalar or vector