acvm/compiler/optimizers/common_subexpression/
merge_expressions.rs1use std::collections::{BTreeMap, BTreeSet, HashMap};
2
3use acir::{
4 AcirField,
5 circuit::{
6 Circuit, Opcode,
7 brillig::{BrilligInputs, BrilligOutputs},
8 opcodes::BlockId,
9 },
10 native_types::{Expression, Witness},
11};
12
13use crate::compiler::{CircuitSimulator, optimizers::GeneralOptimizer};
14
15pub(crate) struct MergeExpressionsOptimizer<F: AcirField> {
16 resolved_blocks: HashMap<BlockId, BTreeSet<Witness>>,
17 modified_gates: HashMap<usize, Opcode<F>>,
18 deleted_gates: BTreeSet<usize>,
19}
20
21impl<F: AcirField> MergeExpressionsOptimizer<F> {
22 pub(crate) fn new() -> Self {
23 MergeExpressionsOptimizer {
24 resolved_blocks: HashMap::new(),
25 modified_gates: HashMap::new(),
26 deleted_gates: BTreeSet::new(),
27 }
28 }
29
30 pub(crate) fn eliminate_intermediate_variable(
58 &mut self,
59 circuit: &Circuit<F>,
60 acir_opcode_positions: Vec<usize>,
61 ) -> (Vec<Opcode<F>>, Vec<usize>) {
62 self.modified_gates.clear();
64 self.deleted_gates.clear();
65 self.resolved_blocks.clear();
66
67 let circuit_io: BTreeSet<Witness> =
69 circuit.circuit_arguments().union(&circuit.public_inputs().0).copied().collect();
70
71 let mut used_witnesses: BTreeMap<Witness, BTreeSet<usize>> = BTreeMap::new();
72 for (i, opcode) in circuit.opcodes.iter().enumerate() {
73 let witnesses = self.witness_inputs(opcode);
74 if let Opcode::MemoryInit { block_id, .. } = opcode {
75 self.resolved_blocks.insert(*block_id, witnesses.clone());
76 }
77 for w in witnesses {
78 if !circuit_io.contains(&w) {
80 used_witnesses.entry(w).or_default().insert(i);
81 }
82 }
83 }
84
85 for (op1, opcode) in circuit.opcodes.iter().enumerate() {
87 if !matches!(opcode, Opcode::AssertZero(_)) {
88 continue;
89 }
90 if let Some(opcode) = self.get_opcode(op1, circuit) {
91 let input_witnesses = self.witness_inputs(&opcode);
92 for w in input_witnesses {
93 let Some(gates_using_w) = used_witnesses.get(&w) else {
94 continue;
95 };
96 if gates_using_w.len() == 2 {
98 let first = *gates_using_w.first().expect("gates_using_w.len == 2");
99 let second = *gates_using_w.last().expect("gates_using_w.len == 2");
100 let op2 = if second == op1 {
101 first
102 } else {
103 assert!(op1 == first);
105 second
106 };
107
108 if op1 != op2 {
112 let (source, target) = if op1 < op2 { (op1, op2) } else { (op2, op1) };
113 let source_opcode = self.get_opcode(source, circuit);
114 let target_opcode = self.get_opcode(target, circuit);
115
116 if let (
117 Some(Opcode::AssertZero(expr_use)),
118 Some(Opcode::AssertZero(expr_define)),
119 ) = (target_opcode, source_opcode)
120 && let Some(expr) =
121 Self::merge_expression(&expr_use, &expr_define, w)
122 {
123 self.modified_gates.insert(target, Opcode::AssertZero(expr));
124 self.deleted_gates.insert(source);
125 let witness_list = CircuitSimulator::expr_witness(&expr_use);
127 let witness_list = witness_list
128 .chain(CircuitSimulator::expr_witness(&expr_define));
129
130 for w2 in witness_list {
131 if !circuit_io.contains(&w2) {
132 used_witnesses.entry(w2).and_modify(|v| {
133 v.insert(target);
134 v.remove(&source);
135 });
136 }
137 }
138 break;
141 }
142 }
143 }
144 }
145 }
146 }
147
148 let mut new_circuit = Vec::new();
150 let mut new_acir_opcode_positions = Vec::new();
151
152 for (i, opcode_position) in acir_opcode_positions.iter().enumerate() {
153 if let Some(opcode) = self.get_opcode(i, circuit) {
154 new_circuit.push(opcode);
155 new_acir_opcode_positions.push(*opcode_position);
156 }
157 }
158 (new_circuit, new_acir_opcode_positions)
159 }
160
161 fn for_each_brillig_input_witness(&self, input: &BrilligInputs<F>, mut f: impl FnMut(Witness)) {
162 match input {
163 BrilligInputs::Single(expr) => {
164 for witness in CircuitSimulator::expr_witness(expr) {
165 f(witness);
166 }
167 }
168 BrilligInputs::Array(exprs) => {
169 for expr in exprs {
170 for witness in CircuitSimulator::expr_witness(expr) {
171 f(witness);
172 }
173 }
174 }
175 BrilligInputs::MemoryArray(block_id) => {
176 for witness in self.resolved_blocks.get(block_id).expect("Unknown block id") {
177 f(*witness);
178 }
179 }
180 }
181 }
182
183 fn for_each_brillig_output_witness(&self, output: &BrilligOutputs, mut f: impl FnMut(Witness)) {
184 match output {
185 BrilligOutputs::Simple(witness) => f(*witness),
186 BrilligOutputs::Array(witnesses) => {
187 for witness in witnesses {
188 f(*witness);
189 }
190 }
191 }
192 }
193
194 fn witness_inputs(&self, opcode: &Opcode<F>) -> BTreeSet<Witness> {
196 match opcode {
197 Opcode::AssertZero(expr) => CircuitSimulator::expr_witness(expr).collect(),
198 Opcode::BlackBoxFuncCall(bb_func) => {
199 let mut witnesses = bb_func.get_input_witnesses();
200 witnesses.extend(bb_func.get_outputs_vec());
201 if let Some(w) = bb_func.get_predicate() {
202 witnesses.insert(w);
203 }
204 witnesses
205 }
206 Opcode::MemoryOp { block_id: _, op } => CircuitSimulator::expr_witness(&op.operation)
207 .chain(CircuitSimulator::expr_witness(&op.index))
208 .chain(CircuitSimulator::expr_witness(&op.value))
209 .collect(),
210
211 Opcode::MemoryInit { block_id: _, init, block_type: _ } => {
212 init.iter().copied().collect()
213 }
214 Opcode::BrilligCall { inputs, outputs, predicate, .. } => {
215 let mut witnesses = BTreeSet::new();
216 for i in inputs {
217 self.for_each_brillig_input_witness(i, |witness| {
218 witnesses.insert(witness);
219 });
220 }
221 witnesses.extend(CircuitSimulator::expr_witness(predicate));
222 for i in outputs {
223 self.for_each_brillig_output_witness(i, |witness| {
224 witnesses.insert(witness);
225 });
226 }
227 witnesses
228 }
229 Opcode::Call { id: _, inputs, outputs, predicate } => {
230 let mut witnesses: BTreeSet<Witness> = inputs.iter().copied().collect();
231 witnesses.extend(outputs);
232 witnesses.extend(CircuitSimulator::expr_witness(predicate));
233 witnesses
234 }
235 }
236 }
237
238 fn merge_expression(
241 target: &Expression<F>,
242 expr: &Expression<F>,
243 witness: Witness,
244 ) -> Option<Expression<F>> {
245 for m in &target.mul_terms {
247 if m.1 == witness || m.2 == witness {
248 return None;
249 }
250 }
251 for m in &expr.mul_terms {
252 if m.1 == witness || m.2 == witness {
253 return None;
254 }
255 }
256
257 for k in &target.linear_combinations {
258 if k.1 == witness {
259 for i in &expr.linear_combinations {
260 if i.1 == witness {
261 assert!(
262 i.0 != F::zero(),
263 "merge_expression: attempting to divide k.0 by F::zero"
264 );
265 let expr = target.add_mul(-(k.0 / i.0), expr);
266 let expr = GeneralOptimizer::optimize(expr);
267 return Some(expr);
268 }
269 }
270 }
271 }
272 None
273 }
274
275 fn get_opcode(&self, index: usize, circuit: &Circuit<F>) -> Option<Opcode<F>> {
280 if self.deleted_gates.contains(&index) {
281 return None;
282 }
283 self.modified_gates.get(&index).or(circuit.opcodes.get(index)).cloned()
284 }
285}
286
#[cfg(test)]
mod tests {
    use crate::{
        assert_circuit_snapshot,
        compiler::{
            CircuitSimulator,
            optimizers::common_subexpression::merge_expressions::MergeExpressionsOptimizer,
        },
    };
    use acir::{
        AcirField, FieldElement,
        circuit::Circuit,
        native_types::{Expression, Witness},
    };

    /// Runs the merge-expressions pass on `circuit` and returns the optimized circuit.
    ///
    /// Asserts that `CircuitSimulator::check_circuit` reports nothing (returns `None`) both
    /// before and after the pass.
    fn merge_expressions(circuit: Circuit<FieldElement>) -> Circuit<FieldElement> {
        assert!(CircuitSimulator::check_circuit(&circuit).is_none());
        let mut merge_optimizer = MergeExpressionsOptimizer::new();
        // One position per opcode: the optimizer only emits opcodes that have a position, so a
        // fixed-size vector would silently drop opcodes from larger test circuits.
        let acir_opcode_positions = vec![0; circuit.opcodes.len()];
        let (opcodes, _) =
            merge_optimizer.eliminate_intermediate_variable(&circuit, acir_opcode_positions);
        let mut optimized_circuit = circuit;
        optimized_circuit.opcodes = opcodes;

        assert!(CircuitSimulator::check_circuit(&optimized_circuit).is_none());
        optimized_circuit
    }

    /// `w1` is used only by the two ASSERTs, so it is eliminated and they collapse into one.
    #[test]
    fn merges_expressions() {
        let src = "
        private parameters: [w0]
        public parameters: []
        return values: [w2]
        ASSERT 2*w1 = w0 + 5
        ASSERT w2 = 4*w1 + 4
        ";
        let circuit = Circuit::from_str(src).unwrap();
        let optimized_circuit = merge_expressions(circuit);
        assert_circuit_snapshot!(optimized_circuit, @r"
        private parameters: [w0]
        public parameters: []
        return values: [w2]
        ASSERT w2 = 2*w0 + 14
        ");
    }

    /// `w1` is also written by the Brillig call, so it has three users and must be kept.
    #[test]
    fn does_not_eliminate_witnesses_returned_from_brillig() {
        let src = "
        private parameters: [w0]
        public parameters: []
        return values: []
        BRILLIG CALL func: 0, predicate: 1, inputs: [], outputs: [w1]
        ASSERT 2*w0 + 3*w1 + w2 + 1 = 0
        ASSERT 2*w0 + 2*w1 + w5 + 1 = 0
        ";
        let circuit = Circuit::from_str(src).unwrap();
        let optimized_circuit = merge_expressions(circuit.clone());
        assert_eq!(circuit, optimized_circuit);
    }

    /// `w1` is a return value of the circuit, so it must not be eliminated.
    #[test]
    fn does_not_eliminate_witnesses_returned_from_circuit() {
        let src = "
        private parameters: [w0]
        public parameters: []
        return values: [w1, w2]
        ASSERT -w0*w0 + w1 = 0
        ASSERT -w1 + w2 = 0
        ";
        let circuit = Circuit::from_str(src).unwrap();
        let optimized_circuit = merge_expressions(circuit.clone());
        assert_eq!(circuit, optimized_circuit);
    }

    /// Merges always fold an earlier opcode into a later one. Here both the `w2` and `w4`
    /// definitions end up folded into the last ASSERT.
    #[test]
    fn does_not_attempt_to_merge_into_previous_opcodes() {
        let src = "
        private parameters: [w0, w1]
        public parameters: []
        return values: []
        ASSERT w0*w0 - w4 = 0
        ASSERT w0*w1 + w5 = 0
        ASSERT -w2 + w4 + w5 = 0
        ASSERT w2 - w3 + w4 + w5 = 0
        BLACKBOX::RANGE input: w3, bits: 32
        ";
        let circuit = Circuit::from_str(src).unwrap();

        let optimized_circuit = merge_expressions(circuit);
        assert_circuit_snapshot!(optimized_circuit, @r"
        private parameters: [w0, w1]
        public parameters: []
        return values: []
        ASSERT w5 = -w0*w1
        ASSERT w3 = 2*w0*w0 + 2*w5
        BLACKBOX::RANGE input: w3, bits: 32
        ");
    }

    /// `w4` is the output of the AND blackbox, so it has three users and must be kept.
    #[test]
    fn takes_blackbox_opcode_outputs_into_account() {
        let src = "
        private parameters: [w0, w1]
        public parameters: [w2]
        return values: [w2]
        BRILLIG CALL func: 0, predicate: 1, inputs: [], outputs: [w3]
        BLACKBOX::AND lhs: w0, rhs: w1, output: w4, bits: 8
        ASSERT w3 - w4 = 0
        ASSERT -w2 + w4 = 0
        ";
        let circuit = Circuit::from_str(src).unwrap();
        let optimized_circuit = merge_expressions(circuit.clone());
        assert_eq!(circuit, optimized_circuit);
    }

    /// A zero coefficient for the eliminated witness in the defining expression must panic
    /// rather than divide by zero.
    #[test]
    #[should_panic(expected = "merge_expression: attempting to divide k.0 by F::zero")]
    fn merge_expression_on_zero_linear_combination_panics() {
        let opcode_a = Expression {
            mul_terms: vec![],
            linear_combinations: vec![(FieldElement::one(), Witness(0))],
            q_c: FieldElement::zero(),
        };
        let opcode_b = Expression {
            mul_terms: vec![],
            linear_combinations: vec![(FieldElement::zero(), Witness(0))],
            q_c: FieldElement::zero(),
        };
        assert_eq!(
            MergeExpressionsOptimizer::merge_expression(&opcode_a, &opcode_b, Witness(0),),
            Some(opcode_a)
        );
    }

    /// `w4` and `w5` appear in the Brillig call's predicate, which counts as a use, so neither
    /// is eliminated.
    #[test]
    fn does_not_eliminate_witnesses_used_in_brillig_call_predicates() {
        let src = "
        private parameters: [w2]
        public parameters: [w0, w1]
        return values: [w3]
        BLACKBOX::RANGE input: w0, bits: 1
        BLACKBOX::RANGE input: w1, bits: 1
        BLACKBOX::RANGE input: w2, bits: 1
        ASSERT w4 = w0*w1
        ASSERT w5 = -w2 + 1
        BRILLIG CALL func: 0, predicate: w4*w5, inputs: [w2], outputs: [w6]
        ASSERT w3 = -w5 + 1
        ";
        let circuit = Circuit::from_str(src).unwrap();
        let optimized_circuit = merge_expressions(circuit.clone());
        assert_eq!(circuit, optimized_circuit);
    }
}