acvm/compiler/optimizers/common_subexpression/merge_expressions.rs

1use std::collections::{BTreeMap, BTreeSet, HashMap};
2
3use acir::{
4    AcirField,
5    circuit::{
6        Circuit, Opcode,
7        brillig::{BrilligInputs, BrilligOutputs},
8        opcodes::BlockId,
9    },
10    native_types::{Expression, Witness},
11};
12
13use crate::compiler::{CircuitSimulator, optimizers::GeneralOptimizer};
14
15pub(crate) struct MergeExpressionsOptimizer<F: AcirField> {
16    resolved_blocks: HashMap<BlockId, BTreeSet<Witness>>,
17    modified_gates: HashMap<usize, Opcode<F>>,
18    deleted_gates: BTreeSet<usize>,
19}
20
21impl<F: AcirField> MergeExpressionsOptimizer<F> {
22    pub(crate) fn new() -> Self {
23        MergeExpressionsOptimizer {
24            resolved_blocks: HashMap::new(),
25            modified_gates: HashMap::new(),
26            deleted_gates: BTreeSet::new(),
27        }
28    }
29
30    /// This pass analyzes the circuit and identifies intermediate variables that are
31    /// only used in two AssertZero opcodes. It then merges the opcode which produces the
32    /// intermediate variable into the second one that uses it
33    ///
34    /// The first pass maps witnesses to the indices of the opcodes using them.
35    /// Public inputs are not considered because they cannot be simplified.
36    /// Witnesses used by MemoryInit opcodes are put in a separate map and marked as used by a Brillig call
37    /// if the memory block is an input to the call.
38    ///
39    /// The second pass looks for AssertZero opcodes having a witness which is only used by another arithmetic opcode.
40    /// In that case, the opcode with the smallest index is merged into the other one via Gaussian elimination.
41    /// For instance, if we have 'w1' used only by these two opcodes,
42    /// `5*w2*w3` and `w1`:
43    /// w2*w3 + 2*w2 + w1 + w3 = 0   // This opcode 'defines' the variable w1
44    /// 2*w3*w4 + w1 + w4 = 0        // which is only used here
45    ///
46    /// For w1 we can say:
47    /// w1 = -1/2*w2*w3 - w2 - 1/2*w3
48    ///
49    /// Then we will remove the first one and modify the second one like this:
50    /// 2*w3*w4 + w4 - w2 - 1/2*w3 - 1/2*w2*w3 = 0
51    ///
52    /// Pre-condition:
53    /// - This pass is relevant for backends that can handle unlimited width and
54    ///   Plonk-ish backends. Although they have a limited width, they can potentially
55    ///   handle expressions with large linear combinations using 'big-add' gates.
56    /// - The CSAT pass should have been run prior to this one.
57    pub(crate) fn eliminate_intermediate_variable(
58        &mut self,
59        circuit: &Circuit<F>,
60        acir_opcode_positions: Vec<usize>,
61    ) -> (Vec<Opcode<F>>, Vec<usize>) {
62        // Initialization
63        self.modified_gates.clear();
64        self.deleted_gates.clear();
65        self.resolved_blocks.clear();
66
67        // Keep track, for each witness, of the gates that use it
68        let circuit_io: BTreeSet<Witness> =
69            circuit.circuit_arguments().union(&circuit.public_inputs().0).copied().collect();
70
71        let mut used_witnesses: BTreeMap<Witness, BTreeSet<usize>> = BTreeMap::new();
72        for (i, opcode) in circuit.opcodes.iter().enumerate() {
73            let witnesses = self.witness_inputs(opcode);
74            if let Opcode::MemoryInit { block_id, .. } = opcode {
75                self.resolved_blocks.insert(*block_id, witnesses.clone());
76            }
77            for w in witnesses {
78                // We do not simplify circuit inputs and outputs
79                if !circuit_io.contains(&w) {
80                    used_witnesses.entry(w).or_default().insert(i);
81                }
82            }
83        }
84
85        // For each opcode, try to get a target opcode to merge with
86        for (op1, opcode) in circuit.opcodes.iter().enumerate() {
87            if !matches!(opcode, Opcode::AssertZero(_)) {
88                continue;
89            }
90            if let Some(opcode) = self.get_opcode(op1, circuit) {
91                let input_witnesses = self.witness_inputs(&opcode);
92                for w in input_witnesses {
93                    let Some(gates_using_w) = used_witnesses.get(&w) else {
94                        continue;
95                    };
96                    // We only consider witness which are used in exactly two arithmetic gates
97                    if gates_using_w.len() == 2 {
98                        let first = *gates_using_w.first().expect("gates_using_w.len == 2");
99                        let second = *gates_using_w.last().expect("gates_using_w.len == 2");
100                        let op2 = if second == op1 {
101                            first
102                        } else {
103                            // sanity check
104                            assert!(op1 == first);
105                            second
106                        };
107
108                        // Merge the opcode with smaller index into the other one
109                        // by updating modified_gates/deleted_gates/used_witnesses
110                        // returns false if it could not merge them
111                        if op1 != op2 {
112                            let (source, target) = if op1 < op2 { (op1, op2) } else { (op2, op1) };
113                            let source_opcode = self.get_opcode(source, circuit);
114                            let target_opcode = self.get_opcode(target, circuit);
115
116                            if let (
117                                Some(Opcode::AssertZero(expr_use)),
118                                Some(Opcode::AssertZero(expr_define)),
119                            ) = (target_opcode, source_opcode)
120                                && let Some(expr) =
121                                    Self::merge_expression(&expr_use, &expr_define, w)
122                            {
123                                self.modified_gates.insert(target, Opcode::AssertZero(expr));
124                                self.deleted_gates.insert(source);
125                                // Update the 'used_witnesses' map to account for the merge.
126                                let witness_list = CircuitSimulator::expr_witness(&expr_use);
127                                let witness_list = witness_list
128                                    .chain(CircuitSimulator::expr_witness(&expr_define));
129
130                                for w2 in witness_list {
131                                    if !circuit_io.contains(&w2) {
132                                        used_witnesses.entry(w2).and_modify(|v| {
133                                            v.insert(target);
134                                            v.remove(&source);
135                                        });
136                                    }
137                                }
138                                // We need to stop here and continue with the next opcode
139                                // because the merge invalidates the current opcode.
140                                break;
141                            }
142                        }
143                    }
144                }
145            }
146        }
147
148        // Construct the new circuit from modified/deleted gates
149        let mut new_circuit = Vec::new();
150        let mut new_acir_opcode_positions = Vec::new();
151
152        for (i, opcode_position) in acir_opcode_positions.iter().enumerate() {
153            if let Some(opcode) = self.get_opcode(i, circuit) {
154                new_circuit.push(opcode);
155                new_acir_opcode_positions.push(*opcode_position);
156            }
157        }
158        (new_circuit, new_acir_opcode_positions)
159    }
160
161    fn for_each_brillig_input_witness(&self, input: &BrilligInputs<F>, mut f: impl FnMut(Witness)) {
162        match input {
163            BrilligInputs::Single(expr) => {
164                for witness in CircuitSimulator::expr_witness(expr) {
165                    f(witness);
166                }
167            }
168            BrilligInputs::Array(exprs) => {
169                for expr in exprs {
170                    for witness in CircuitSimulator::expr_witness(expr) {
171                        f(witness);
172                    }
173                }
174            }
175            BrilligInputs::MemoryArray(block_id) => {
176                for witness in self.resolved_blocks.get(block_id).expect("Unknown block id") {
177                    f(*witness);
178                }
179            }
180        }
181    }
182
183    fn for_each_brillig_output_witness(&self, output: &BrilligOutputs, mut f: impl FnMut(Witness)) {
184        match output {
185            BrilligOutputs::Simple(witness) => f(*witness),
186            BrilligOutputs::Array(witnesses) => {
187                for witness in witnesses {
188                    f(*witness);
189                }
190            }
191        }
192    }
193
194    // Returns the input witnesses used by the opcode
195    fn witness_inputs(&self, opcode: &Opcode<F>) -> BTreeSet<Witness> {
196        match opcode {
197            Opcode::AssertZero(expr) => CircuitSimulator::expr_witness(expr).collect(),
198            Opcode::BlackBoxFuncCall(bb_func) => {
199                let mut witnesses = bb_func.get_input_witnesses();
200                witnesses.extend(bb_func.get_outputs_vec());
201                if let Some(w) = bb_func.get_predicate() {
202                    witnesses.insert(w);
203                }
204                witnesses
205            }
206            Opcode::MemoryOp { block_id: _, op } => CircuitSimulator::expr_witness(&op.operation)
207                .chain(CircuitSimulator::expr_witness(&op.index))
208                .chain(CircuitSimulator::expr_witness(&op.value))
209                .collect(),
210
211            Opcode::MemoryInit { block_id: _, init, block_type: _ } => {
212                init.iter().copied().collect()
213            }
214            Opcode::BrilligCall { inputs, outputs, predicate, .. } => {
215                let mut witnesses = BTreeSet::new();
216                for i in inputs {
217                    self.for_each_brillig_input_witness(i, |witness| {
218                        witnesses.insert(witness);
219                    });
220                }
221                witnesses.extend(CircuitSimulator::expr_witness(predicate));
222                for i in outputs {
223                    self.for_each_brillig_output_witness(i, |witness| {
224                        witnesses.insert(witness);
225                    });
226                }
227                witnesses
228            }
229            Opcode::Call { id: _, inputs, outputs, predicate } => {
230                let mut witnesses: BTreeSet<Witness> = inputs.iter().copied().collect();
231                witnesses.extend(outputs);
232                witnesses.extend(CircuitSimulator::expr_witness(predicate));
233                witnesses
234            }
235        }
236    }
237
238    // Merge 'expr' into 'target' via Gaussian elimination on 'w'
239    // Returns None if the expressions cannot be merged
240    fn merge_expression(
241        target: &Expression<F>,
242        expr: &Expression<F>,
243        witness: Witness,
244    ) -> Option<Expression<F>> {
245        // Check that the witness is not part of multiplication terms
246        for m in &target.mul_terms {
247            if m.1 == witness || m.2 == witness {
248                return None;
249            }
250        }
251        for m in &expr.mul_terms {
252            if m.1 == witness || m.2 == witness {
253                return None;
254            }
255        }
256
257        for k in &target.linear_combinations {
258            if k.1 == witness {
259                for i in &expr.linear_combinations {
260                    if i.1 == witness {
261                        assert!(
262                            i.0 != F::zero(),
263                            "merge_expression: attempting to divide k.0 by F::zero"
264                        );
265                        let expr = target.add_mul(-(k.0 / i.0), expr);
266                        let expr = GeneralOptimizer::optimize(expr);
267                        return Some(expr);
268                    }
269                }
270            }
271        }
272        None
273    }
274
275    /// Returns the 'updated' opcode at the given index in the circuit
276    /// The modifications to the circuits are stored with 'deleted_gates' and 'modified_gates'
277    /// These structures are used to give the 'updated' opcode.
278    /// For instance, if the opcode has been deleted inside 'deleted_gates', then it returns None.
279    fn get_opcode(&self, index: usize, circuit: &Circuit<F>) -> Option<Opcode<F>> {
280        if self.deleted_gates.contains(&index) {
281            return None;
282        }
283        self.modified_gates.get(&index).or(circuit.opcodes.get(index)).cloned()
284    }
285}
286
#[cfg(test)]
mod tests {
    use acir::{
        AcirField, FieldElement,
        circuit::Circuit,
        native_types::{Expression, Witness},
    };

    use crate::{
        assert_circuit_snapshot,
        compiler::{
            CircuitSimulator,
            optimizers::common_subexpression::merge_expressions::MergeExpressionsOptimizer,
        },
    };

    /// Runs the merge pass over `circuit`, asserting that the circuit is valid both before
    /// and after the optimization.
    fn merge_expressions(circuit: Circuit<FieldElement>) -> Circuit<FieldElement> {
        assert!(CircuitSimulator::check_circuit(&circuit).is_none());

        // Dummy positions: more entries than any circuit below has opcodes.
        let positions = vec![0; 20];
        let (opcodes, _) =
            MergeExpressionsOptimizer::new().eliminate_intermediate_variable(&circuit, positions);

        let mut optimized = circuit;
        optimized.opcodes = opcodes;
        // The optimized circuit must still be well-formed.
        assert!(CircuitSimulator::check_circuit(&optimized).is_none());
        optimized
    }

    /// Parses `src` and asserts that the merge pass leaves the circuit untouched.
    fn assert_unchanged_by_merge(src: &str) {
        let circuit = Circuit::from_str(src).unwrap();
        let optimized = merge_expressions(circuit.clone());
        assert_eq!(circuit, optimized);
    }

    #[test]
    fn merges_expressions() {
        let src = "
        private parameters: [w0]
        public parameters: []
        return values: [w2]
        ASSERT 2*w1 = w0 + 5
        ASSERT w2 = 4*w1 + 4
        ";
        let optimized_circuit = merge_expressions(Circuit::from_str(src).unwrap());
        assert_circuit_snapshot!(optimized_circuit, @r"
        private parameters: [w0]
        public parameters: []
        return values: [w2]
        ASSERT w2 = 2*w0 + 14
        ");
    }

    #[test]
    fn does_not_eliminate_witnesses_returned_from_brillig() {
        // w1 is also written by the Brillig call, so it has three users.
        assert_unchanged_by_merge(
            "
        private parameters: [w0]
        public parameters: []
        return values: []
        BRILLIG CALL func: 0, predicate: 1, inputs: [], outputs: [w1]
        ASSERT 2*w0 + 3*w1 + w2 + 1 = 0
        ASSERT 2*w0 + 2*w1 + w5 + 1 = 0
        ",
        );
    }

    #[test]
    fn does_not_eliminate_witnesses_returned_from_circuit() {
        // w1 is a return value, so it must survive the optimization.
        assert_unchanged_by_merge(
            "
        private parameters: [w0]
        public parameters: []
        return values: [w1, w2]
        ASSERT -w0*w0 + w1 = 0
        ASSERT -w1 + w2 = 0
        ",
        );
    }

    #[test]
    fn does_not_attempt_to_merge_into_previous_opcodes() {
        let src = "
        private parameters: [w0, w1]
        public parameters: []
        return values: []
        ASSERT w0*w0 - w4 = 0
        ASSERT w0*w1 + w5 = 0
        ASSERT -w2 + w4 + w5 = 0
        ASSERT w2 - w3 + w4 + w5 = 0
        BLACKBOX::RANGE input: w3, bits: 32
        ";
        let optimized_circuit = merge_expressions(Circuit::from_str(src).unwrap());
        assert_circuit_snapshot!(optimized_circuit, @r"
        private parameters: [w0, w1]
        public parameters: []
        return values: []
        ASSERT w5 = -w0*w1
        ASSERT w3 = 2*w0*w0 + 2*w5
        BLACKBOX::RANGE input: w3, bits: 32
        ");
    }

    #[test]
    fn takes_blackbox_opcode_outputs_into_account() {
        // Regression test for https://github.com/noir-lang/noir/issues/6527
        // Previously we would not track the usage of witness 4 in the output of the blackbox function.
        // We would then merge the final two opcodes losing the check that the brillig call must match
        // with `w0 ^ w1`.
        assert_unchanged_by_merge(
            "
        private parameters: [w0, w1]
        public parameters: []
        return values: [w2]
        BRILLIG CALL func: 0, predicate: 1, inputs: [], outputs: [w3]
        BLACKBOX::AND lhs: w0, rhs: w1, output: w4, bits: 8
        ASSERT w3 - w4 = 0
        ASSERT -w2 + w4 = 0
        ",
        );
    }

    #[test]
    #[should_panic(expected = "merge_expression: attempting to divide k.0 by F::zero")]
    fn merge_expression_on_zero_linear_combination_panics() {
        // `expr` mentions w0 only with a zero coefficient, so eliminating w0 divides by zero.
        let target = Expression {
            mul_terms: vec![],
            linear_combinations: vec![(FieldElement::one(), Witness(0))],
            q_c: FieldElement::zero(),
        };
        let expr = Expression {
            mul_terms: vec![],
            linear_combinations: vec![(FieldElement::zero(), Witness(0))],
            q_c: FieldElement::zero(),
        };
        let _ = MergeExpressionsOptimizer::merge_expression(&target, &expr, Witness(0));
    }

    #[test]
    fn does_not_eliminate_witnesses_used_in_brillig_call_predicates() {
        // w5 also appears in the Brillig call predicate, so it has three users.
        assert_unchanged_by_merge(
            "
        private parameters: [w2]
        public parameters: [w0, w1]
        return values: [w3]
        BLACKBOX::RANGE input: w0, bits: 1
        BLACKBOX::RANGE input: w1, bits: 1
        BLACKBOX::RANGE input: w2, bits: 1
        ASSERT w4 = w0*w1
        ASSERT w5 = -w2 + 1
        BRILLIG CALL func: 0, predicate: w4*w5, inputs: [w2], outputs: [w6]
        ASSERT w3 = -w5 + 1
        ",
        );
    }
}