// acvm/compiler/optimizers/mod.rs
use std::collections::BTreeMap;
use acir::{
AcirField,
circuit::{Circuit, Opcode, brillig::BrilligFunctionId},
};
mod general;
mod merge_expressions;
mod redundant_range;
mod unused_memory;
pub(crate) use general::GeneralOptimizer;
pub(crate) use merge_expressions::MergeExpressionsOptimizer;
pub(crate) use redundant_range::RangeOptimizer;
use tracing::info;
use self::unused_memory::UnusedMemoryOptimizer;
use super::{AcirTransformationMap, transform_assert_messages};
/// Runs the ACIR optimization passes over `acir`.
///
/// Position tracking starts from the identity mapping: opcode `i` is at
/// position `i`. That list is threaded through [`optimize_internal`], and the
/// final list is used to build an [`AcirTransformationMap`]. The circuit's
/// `assert_messages` are then passed through `transform_assert_messages`
/// with that map.
///
/// Returns the optimized circuit together with the transformation map.
pub fn optimize<F: AcirField>(
    acir: Circuit<F>,
    brillig_side_effects: &BTreeMap<BrilligFunctionId, bool>,
) -> (Circuit<F>, AcirTransformationMap) {
    // Before any pass has run, every opcode is at its own index.
    let initial_positions: Vec<usize> = (0..acir.opcodes.len()).collect();

    let (mut optimized, final_positions) =
        optimize_internal(acir, initial_positions, brillig_side_effects);

    let transformation_map = AcirTransformationMap::new(&final_positions);
    optimized.assert_messages =
        transform_assert_messages(optimized.assert_messages, &transformation_map);

    (optimized, transformation_map)
}
/// Applies the individual optimization passes to `acir`.
///
/// `acir_opcode_positions` is threaded through the passes that take it.
/// [`optimize`] seeds it with `0..acir.opcodes.len()`. The returned vector is
/// whatever the last pass (`RangeOptimizer::replace_redundant_ranges`)
/// produces.
///
/// Passes, in order:
/// 1. `GeneralOptimizer::optimize` on the expression of every `AssertZero`
///    opcode. This is a one-to-one map, so the opcode count and the position
///    list are unchanged.
/// 2. `UnusedMemoryOptimizer::remove_unused_memory_initializations`.
/// 3. `RangeOptimizer::replace_redundant_ranges`, configured with
///    `brillig_side_effects`.
///
/// If the circuit consists of exactly one `BrilligCall` opcode, it and the
/// positions are returned untouched.
#[tracing::instrument(level = "trace", name = "optimize_acir", skip(acir, acir_opcode_positions))]
pub(super) fn optimize_internal<F: AcirField>(
    acir: Circuit<F>,
    acir_opcode_positions: Vec<usize>,
    brillig_side_effects: &BTreeMap<BrilligFunctionId, bool>,
) -> (Circuit<F>, Vec<usize>) {
    // A single Brillig call is treated as a fully unconstrained program:
    // skip every pass and hand back the input unchanged.
    if acir.opcodes.len() == 1 && matches!(acir.opcodes[0], Opcode::BrilligCall { .. }) {
        info!("Program is fully unconstrained, skipping optimization pass");
        return (acir, acir_opcode_positions);
    }
    info!("Number of opcodes before: {}", acir.opcodes.len());

    // Pass 1: simplify each AssertZero expression in place. Other opcodes pass
    // through unchanged. Because this is a 1:1 map, `acir_opcode_positions`
    // stays valid.
    let opcodes: Vec<Opcode<F>> = acir
        .opcodes
        .into_iter()
        .map(|opcode| {
            if let Opcode::AssertZero(arith_expr) = opcode {
                Opcode::AssertZero(GeneralOptimizer::optimize(arith_expr))
            } else {
                opcode
            }
        })
        .collect();
    let acir = Circuit { opcodes, ..acir };

    // Pass 2: drop unused memory initializations and update the positions.
    let memory_optimizer = UnusedMemoryOptimizer::new(acir);
    let (acir, acir_opcode_positions) =
        memory_optimizer.remove_unused_memory_initializations(acir_opcode_positions);

    // Pass 3: replace redundant range constraints and update the positions.
    let range_optimizer = RangeOptimizer::new(acir, brillig_side_effects);
    let (acir, acir_opcode_positions) =
        range_optimizer.replace_redundant_ranges(acir_opcode_positions);

    info!("Number of opcodes after: {}", acir.opcodes.len());
    (acir, acir_opcode_positions)
}