acvm/compiler/optimizers/unused_memory.rs

1use acir::{
2    AcirField,
3    circuit::{Circuit, Opcode, brillig::BrilligInputs, opcodes::BlockId},
4};
5use std::collections::HashSet;
6
7/// `UnusedMemoryOptimizer` will remove initializations of memory blocks which are unused.
8/// A first pass collects all memory blocks which are initialized but discards the ones
9/// which are used in a MemoryOp or as input to a BrilligCall.
10/// The second pass removes the opcodes tagged as unused by the first pass.
11pub(crate) struct UnusedMemoryOptimizer<F: AcirField> {
12    unused_memory_initializations: HashSet<BlockId>,
13    circuit: Circuit<F>,
14}
15
16impl<F: AcirField> UnusedMemoryOptimizer<F> {
17    /// Creates a new `UnusedMemoryOptimizer ` by collecting unused memory init
18    /// opcodes from `Circuit`.
19    pub(crate) fn new(circuit: Circuit<F>) -> Self {
20        let unused_memory_initializations = Self::collect_unused_memory_initializations(&circuit);
21        Self { circuit, unused_memory_initializations }
22    }
23
24    /// Creates a set of ids for memory blocks for which no [`Opcode::MemoryOp`]s exist.
25    ///
26    /// These memory blocks can be safely removed.
27    fn collect_unused_memory_initializations(circuit: &Circuit<F>) -> HashSet<BlockId> {
28        let mut unused_memory_initialization = HashSet::new();
29
30        for opcode in &circuit.opcodes {
31            match opcode {
32                Opcode::MemoryInit { block_id, .. } => {
33                    unused_memory_initialization.insert(*block_id);
34                }
35                Opcode::MemoryOp { block_id, .. } => {
36                    unused_memory_initialization.remove(block_id);
37                }
38                Opcode::BrilligCall { inputs, .. } => {
39                    for input in inputs {
40                        if let BrilligInputs::MemoryArray(block) = input {
41                            unused_memory_initialization.remove(block);
42                        }
43                    }
44                }
45                _ => (),
46            }
47        }
48        unused_memory_initialization
49    }
50
51    /// Returns a `Circuit` where [`Opcode::MemoryInit`]s for unused memory blocks are dropped.
52    pub(crate) fn remove_unused_memory_initializations(
53        self,
54        order_list: Vec<usize>,
55    ) -> (Circuit<F>, Vec<usize>) {
56        let mut new_order_list = Vec::with_capacity(order_list.len());
57        let mut optimized_opcodes = Vec::with_capacity(self.circuit.opcodes.len());
58        for (idx, opcode) in self.circuit.opcodes.into_iter().enumerate() {
59            match opcode {
60                Opcode::MemoryInit { block_id, block_type, .. }
61                    if !block_type.is_databus()
62                        && self.unused_memory_initializations.contains(&block_id) =>
63                {
64                    // Drop opcode
65                }
66                _ => {
67                    new_order_list.push(order_list[idx]);
68                    optimized_opcodes.push(opcode);
69                }
70            }
71        }
72
73        (Circuit { opcodes: optimized_opcodes, ..self.circuit }, new_order_list)
74    }
75}
76
#[cfg(test)]
mod tests {
    use super::*;
    use crate::{assert_circuit_snapshot, compiler::CircuitSimulator};

    /// An initialized block with no memory ops and no Brillig uses has its `INIT` dropped.
    #[test]
    fn unused_memory_is_removed() {
        let source = "
        private parameters: [w0, w1]
        public parameters: []
        return values: [w2]
        INIT b0 = [w0, w1]
        ASSERT w0 - w1 - w2 = 0
        ";
        let input = Circuit::from_str(source).unwrap();
        assert!(CircuitSimulator::check_circuit(&input).is_none());

        let optimizer = UnusedMemoryOptimizer::new(input);
        assert_eq!(optimizer.unused_memory_initializations.len(), 1);

        let order_list = vec![0, 1];
        let (optimized, _) = optimizer.remove_unused_memory_initializations(order_list);
        assert!(CircuitSimulator::check_circuit(&optimized).is_none());
        assert_circuit_snapshot!(optimized, @r"
        private parameters: [w0, w1]
        public parameters: []
        return values: [w2]
        ASSERT w2 = w0 - w1
        ");
    }

    /// A databus block is reported as unused but its initialization is still kept.
    #[test]
    fn databus_is_not_removed() {
        let source = "
        private parameters: [w0, w1]
        public parameters: []
        return values: [w2]
        INIT RETURNDATA b0 = [w0, w1]
        ASSERT w2 = w0 - w1
        ";
        let original = Circuit::from_str(source).unwrap();
        assert!(CircuitSimulator::check_circuit(&original).is_none());

        let optimizer = UnusedMemoryOptimizer::new(original.clone());
        // Detection flags the block; only the removal step spares databus initializations.
        assert_eq!(optimizer.unused_memory_initializations.len(), 1);

        let order_list = vec![0, 1];
        let (optimized, _) = optimizer.remove_unused_memory_initializations(order_list);
        assert!(CircuitSimulator::check_circuit(&optimized).is_none());
        assert_eq!(optimized, original);
    }
}