// acvm/compiler/optimizers/general.rs
use acir::{
AcirField,
native_types::{Expression, Witness},
};
use indexmap::IndexMap;
/// Simplifies a single arithmetic [`Expression`].
///
/// Duplicate multiplication terms and duplicate linear terms are folded together, and any
/// term whose combined coefficient ends up as zero is dropped. The constant term is not
/// touched.
pub(crate) struct GeneralOptimizer;

impl GeneralOptimizer {
    /// Returns `opcode` with like terms merged and zero-coefficient terms removed.
    ///
    /// Multiplication terms are simplified first, then linear terms.
    pub(crate) fn optimize<F: AcirField>(opcode: Expression<F>) -> Expression<F> {
        simplify_linear_terms(simplify_mul_terms(opcode))
    }
}
/// Merges multiplication terms over the same unordered pair of witnesses and drops every
/// product whose summed coefficient is zero.
///
/// `a*b` and `b*a` are treated as one term. Each surviving term is emitted with the smaller
/// witness on the left, in the order in which its pair was first encountered.
fn simplify_mul_terms<F: AcirField>(mut gate: Expression<F>) -> Expression<F> {
    let summed = gate.mul_terms.into_iter().fold(
        IndexMap::<(Witness, Witness), F>::new(),
        |mut acc, (coefficient, lhs, rhs)| {
            // Canonicalize the pair so that commuted products share a single entry.
            let key = if rhs < lhs { (rhs, lhs) } else { (lhs, rhs) };
            *acc.entry(key).or_insert_with(F::zero) += coefficient;
            acc
        },
    );

    gate.mul_terms = summed
        .into_iter()
        .filter_map(|((lhs, rhs), coefficient)| {
            (!coefficient.is_zero()).then_some((coefficient, lhs, rhs))
        })
        .collect();
    gate
}
/// Merges linear terms that refer to the same witness and drops every term whose summed
/// coefficient is zero.
///
/// Surviving terms keep the order in which each witness first appeared.
fn simplify_linear_terms<F: AcirField>(mut gate: Expression<F>) -> Expression<F> {
    let mut coefficient_by_witness: IndexMap<Witness, F> = IndexMap::new();
    for (coefficient, witness) in gate.linear_combinations.drain(..) {
        let slot = coefficient_by_witness.entry(witness).or_insert_with(F::zero);
        *slot += coefficient;
    }

    gate.linear_combinations = coefficient_by_witness
        .into_iter()
        .filter_map(|(witness, coefficient)| {
            if coefficient.is_zero() { None } else { Some((coefficient, witness)) }
        })
        .collect();
    gate
}
#[cfg(test)]
mod tests {
    use acir::{
        FieldElement,
        circuit::{Circuit, Opcode},
    };

    use crate::{assert_circuit_snapshot, compiler::optimizers::GeneralOptimizer};

    /// Runs [`GeneralOptimizer::optimize`] on every `AssertZero` opcode of `circuit`.
    ///
    /// All other opcodes, and every other field of the circuit, are passed through unchanged.
    fn optimize(mut circuit: Circuit<FieldElement>) -> Circuit<FieldElement> {
        // Move the opcodes out by value rather than cloning the whole circuit; the rest of
        // `circuit` is reused as the result.
        circuit.opcodes = std::mem::take(&mut circuit.opcodes)
            .into_iter()
            .map(|opcode| match opcode {
                Opcode::AssertZero(arith_expr) => {
                    Opcode::AssertZero(GeneralOptimizer::optimize(arith_expr))
                }
                other => other,
            })
            .collect();
        circuit
    }

    #[test]
    fn removes_zero_coefficients_from_mul_terms() {
        let src = "
private parameters: [w0, w1]
public parameters: []
return values: []
// The first multiplication should be removed
ASSERT 0*w0*w1 + w0*w1 = 0
";
        let circuit = Circuit::from_str(src).unwrap();
        let optimized_circuit = optimize(circuit);
        assert_circuit_snapshot!(optimized_circuit, @r"
private parameters: [w0, w1]
public parameters: []
return values: []
ASSERT 0 = w0*w1
");
    }

    #[test]
    fn removes_zero_coefficients_from_linear_terms() {
        let src = "
private parameters: [w0, w1]
public parameters: []
return values: []
// The first linear combination should be removed
ASSERT 0*w0 + w1 = 0
";
        let circuit = Circuit::from_str(src).unwrap();
        let optimized_circuit = optimize(circuit);
        assert_circuit_snapshot!(optimized_circuit, @r"
private parameters: [w0, w1]
public parameters: []
return values: []
ASSERT w1 = 0
");
    }

    #[test]
    fn simplifies_mul_terms() {
        let src = "
private parameters: [w0, w1]
public parameters: []
return values: []
// There are all mul terms with the same variables so we should end up with just one
// that is the sum of all the coefficients
ASSERT 2*w0*w1 + 3*w1*w0 + 4*w0*w1 = 0
";
        let circuit = Circuit::from_str(src).unwrap();
        let optimized_circuit = optimize(circuit);
        assert_circuit_snapshot!(optimized_circuit, @r"
private parameters: [w0, w1]
public parameters: []
return values: []
ASSERT 0 = 9*w0*w1
");
    }

    #[test]
    fn removes_zero_coefficients_after_simplifying_mul_terms() {
        let src = "
private parameters: [w0, w1]
public parameters: []
return values: []
ASSERT 2*w0*w1 + 3*w1*w0 - 5*w0*w1 = 0
";
        let circuit = Circuit::from_str(src).unwrap();
        let optimized_circuit = optimize(circuit);
        assert_circuit_snapshot!(optimized_circuit, @r"
private parameters: [w0, w1]
public parameters: []
return values: []
ASSERT 0 = 0
");
    }

    #[test]
    fn simplifies_linear_terms() {
        let src = "
private parameters: [w0, w1]
public parameters: []
return values: []
// These are all linear terms with the same variable so we should end up with just one
// that is the sum of all the coefficients
ASSERT w0 + 2*w0 + 3*w0 = 0
";
        let circuit = Circuit::from_str(src).unwrap();
        let optimized_circuit = optimize(circuit);
        assert_circuit_snapshot!(optimized_circuit, @r"
private parameters: [w0, w1]
public parameters: []
return values: []
ASSERT 0 = 6*w0
");
    }

    #[test]
    fn removes_zero_coefficients_after_simplifying_linear_terms() {
        let src = "
private parameters: [w0, w1]
public parameters: []
return values: []
ASSERT w0 + 2*w0 - 3*w0 = 0
";
        let circuit = Circuit::from_str(src).unwrap();
        let optimized_circuit = optimize(circuit);
        assert_circuit_snapshot!(optimized_circuit, @r"
private parameters: [w0, w1]
public parameters: []
return values: []
ASSERT 0 = 0
");
    }
}