1use std::collections::BTreeMap;
31
32use acir::{
33 AcirField,
34 circuit::{
35 Circuit, Opcode,
36 brillig::{BrilligFunctionId, BrilligInputs, BrilligOutputs},
37 opcodes::{BlackBoxFuncCall, FunctionInput, MemOp},
38 },
39 native_types::{Expression, Witness},
40};
41use indexmap::IndexMap;
42
43mod csat;
44mod merge_expressions;
45
46use csat::CSatTransformer;
47use merge_expressions::MergeExpressionsOptimizer;
48
49use std::hash::BuildHasher;
50use tracing::info;
51
52use super::RangeOptimizer;
53
54const DEFAULT_MAX_TRANSFORMER_PASSES: usize = 3;
56
57#[tracing::instrument(level = "trace", name = "transform_acir", skip(acir, acir_opcode_positions))]
67pub(super) fn transform_internal<F: AcirField>(
68 mut acir: Circuit<F>,
69 mut acir_opcode_positions: Vec<usize>,
70 brillig_side_effects: &BTreeMap<BrilligFunctionId, bool>,
71 max_transformer_passes_or_default: Option<usize>,
72) -> (Circuit<F>, Vec<usize>, bool) {
73 if acir.opcodes.len() == 1 && matches!(acir.opcodes[0], Opcode::BrilligCall { .. }) {
74 info!("Program is fully unconstrained, skipping transformation pass");
75 return (acir, acir_opcode_positions, true);
76 }
77
78 let mut prev_opcodes_hash = rustc_hash::FxBuildHasher.hash_one(&acir.opcodes);
80
81 let mut opcodes_hash_stabilized = false;
83
84 let max_transformer_passes =
85 max_transformer_passes_or_default.unwrap_or(DEFAULT_MAX_TRANSFORMER_PASSES);
86
87 for _ in 0..max_transformer_passes {
90 info!("Number of opcodes {}", acir.opcodes.len());
91 let (new_acir, new_acir_opcode_positions) =
92 transform_internal_once(acir, acir_opcode_positions, brillig_side_effects);
93
94 acir = new_acir;
95 acir_opcode_positions = new_acir_opcode_positions;
96
97 let new_opcodes_hash = rustc_hash::FxBuildHasher.hash_one(&acir.opcodes);
98
99 if new_opcodes_hash == prev_opcodes_hash {
100 opcodes_hash_stabilized = true;
101 break;
102 }
103 prev_opcodes_hash = new_opcodes_hash;
104 }
105
106 acir.current_witness_index = max_witness(&acir).witness_index();
109
110 (acir, acir_opcode_positions, opcodes_hash_stabilized)
111}
112
113#[tracing::instrument(
123 level = "trace",
124 name = "transform_acir_once",
125 skip(acir, acir_opcode_positions)
126)]
127fn transform_internal_once<F: AcirField>(
128 mut acir: Circuit<F>,
129 acir_opcode_positions: Vec<usize>,
130 brillig_side_effects: &BTreeMap<BrilligFunctionId, bool>,
131) -> (Circuit<F>, Vec<usize>) {
132 let mut transformer = CSatTransformer::new(4);
137 for value in acir.circuit_arguments() {
138 transformer.mark_solvable(value);
139 }
140
141 let mut new_acir_opcode_positions: Vec<usize> = Vec::with_capacity(acir_opcode_positions.len());
142 let mut transformed_opcodes = Vec::new();
145
146 let mut next_witness_index = acir.current_witness_index + 1;
147 let mut intermediate_variables: IndexMap<Expression<F>, (F, Witness)> = IndexMap::new();
150 for (index, opcode) in acir.opcodes.into_iter().enumerate() {
151 match opcode {
152 Opcode::AssertZero(arith_expr) => {
153 let len = intermediate_variables.len();
154
155 let arith_expr = transformer.transform(
156 arith_expr,
157 &mut intermediate_variables,
158 &mut next_witness_index,
159 );
160
161 let mut new_opcodes = Vec::new();
162 for (g, (norm, w)) in intermediate_variables.iter().skip(len) {
163 let mut intermediate_opcode = g * *norm;
165 intermediate_opcode.linear_combinations.push((-F::one(), *w));
167 intermediate_opcode.sort();
168 new_opcodes.push(intermediate_opcode);
169 }
170 new_opcodes.push(arith_expr);
171 for opcode in new_opcodes {
172 new_acir_opcode_positions.push(acir_opcode_positions[index]);
173 transformed_opcodes.push(Opcode::AssertZero(opcode));
174 }
175 }
176 Opcode::BlackBoxFuncCall(ref func) => {
177 for witness in func.get_outputs_vec() {
178 transformer.mark_solvable(witness);
179 }
180
181 new_acir_opcode_positions.push(acir_opcode_positions[index]);
182 transformed_opcodes.push(opcode);
183 }
184 Opcode::MemoryInit { .. } => {
185 new_acir_opcode_positions.push(acir_opcode_positions[index]);
187 transformed_opcodes.push(opcode);
188 }
189 Opcode::MemoryOp { ref op, .. } => {
190 for (_, witness1, witness2) in &op.value.mul_terms {
191 transformer.mark_solvable(*witness1);
192 transformer.mark_solvable(*witness2);
193 }
194 for (_, witness) in &op.value.linear_combinations {
195 transformer.mark_solvable(*witness);
196 }
197 new_acir_opcode_positions.push(acir_opcode_positions[index]);
198 transformed_opcodes.push(opcode);
199 }
200 Opcode::BrilligCall { ref outputs, .. } => {
201 for output in outputs {
202 match output {
203 BrilligOutputs::Simple(witness) => transformer.mark_solvable(*witness),
204 BrilligOutputs::Array(witnesses) => {
205 for witness in witnesses {
206 transformer.mark_solvable(*witness);
207 }
208 }
209 }
210 }
211
212 new_acir_opcode_positions.push(acir_opcode_positions[index]);
213 transformed_opcodes.push(opcode);
214 }
215 Opcode::Call { ref outputs, .. } => {
216 for witness in outputs {
217 transformer.mark_solvable(*witness);
218 }
219
220 new_acir_opcode_positions.push(acir_opcode_positions[index]);
223 transformed_opcodes.push(opcode);
224 }
225 }
226 }
227
228 let current_witness_index = next_witness_index - 1;
229
230 acir = Circuit {
231 current_witness_index,
232 opcodes: transformed_opcodes,
233 ..acir
235 };
236
237 let mut merge_optimizer = MergeExpressionsOptimizer::new();
239
240 let (opcodes, new_acir_opcode_positions) =
241 merge_optimizer.eliminate_intermediate_variable(&acir, new_acir_opcode_positions);
242
243 acir = Circuit {
245 opcodes,
246 ..acir
248 };
249
250 let range_optimizer = RangeOptimizer::new(acir, brillig_side_effects);
254 let (acir, new_acir_opcode_positions) =
255 range_optimizer.replace_redundant_ranges(new_acir_opcode_positions);
256
257 (acir, new_acir_opcode_positions)
258}
259
260fn max_witness<F: AcirField>(circuit: &Circuit<F>) -> Witness {
262 let mut witnesses = WitnessFolder::new(Witness::default(), |state, witness| {
263 *state = witness.max(*state);
264 });
265 witnesses.fold_circuit(circuit);
266 witnesses.into_state()
267}
268
/// Accumulates a value of type `S` over a sequence of witnesses by calling `accumulate` on each
/// one.
///
/// Used by `max_witness` to find the largest witness of a circuit.
struct WitnessFolder<S, A> {
    /// Current accumulated value; handed back by `into_state`.
    state: S,
    /// Callback that updates `state` with each folded witness.
    accumulate: A,
}
274
275impl<S, A> WitnessFolder<S, A>
276where
277 A: Fn(&mut S, Witness),
278{
279 fn new(init: S, accumulate: A) -> Self {
281 Self { state: init, accumulate }
282 }
283
284 fn into_state(self) -> S {
286 self.state
287 }
288
289 fn fold_circuit<F: AcirField>(&mut self, circuit: &Circuit<F>) {
291 self.fold_many(circuit.private_parameters.iter());
292 self.fold_many(circuit.public_parameters.0.iter());
293 self.fold_many(circuit.return_values.0.iter());
294 for opcode in &circuit.opcodes {
295 self.fold_opcode(opcode);
296 }
297 }
298
299 fn fold(&mut self, witness: Witness) {
301 (self.accumulate)(&mut self.state, witness);
302 }
303
304 fn fold_many<'w, I: Iterator<Item = &'w Witness>>(&mut self, witnesses: I) {
306 for witness in witnesses {
307 self.fold(*witness);
308 }
309 }
310
311 fn fold_opcode<F: AcirField>(&mut self, opcode: &Opcode<F>) {
313 match opcode {
314 Opcode::AssertZero(expr) => {
315 self.fold_expr(expr);
316 }
317 Opcode::BlackBoxFuncCall(call) => self.fold_blackbox(call),
318 Opcode::MemoryOp { block_id: _, op } => {
319 let MemOp { operation, index, value } = op;
320 self.fold_expr(operation);
321 self.fold_expr(index);
322 self.fold_expr(value);
323 }
324 Opcode::MemoryInit { block_id: _, init, block_type: _ } => {
325 for witness in init {
326 self.fold(*witness);
327 }
328 }
329 Opcode::BrilligCall { id: _, inputs, outputs, predicate } => {
332 self.fold_expr(predicate);
333 self.fold_brillig_inputs(inputs);
334 self.fold_brillig_outputs(outputs);
335 }
336 Opcode::Call { id: _, inputs, outputs, predicate } => {
337 self.fold_expr(predicate);
338 self.fold_many(inputs.iter());
339 self.fold_many(outputs.iter());
340 }
341 }
342 }
343
344 fn fold_expr<F: AcirField>(&mut self, expr: &Expression<F>) {
345 for i in &expr.mul_terms {
346 self.fold(i.1);
347 self.fold(i.2);
348 }
349 for i in &expr.linear_combinations {
350 self.fold(i.1);
351 }
352 }
353
354 fn fold_brillig_inputs<F: AcirField>(&mut self, inputs: &[BrilligInputs<F>]) {
355 for input in inputs {
356 match input {
357 BrilligInputs::Single(expr) => {
358 self.fold_expr(expr);
359 }
360 BrilligInputs::Array(exprs) => {
361 for expr in exprs {
362 self.fold_expr(expr);
363 }
364 }
365 BrilligInputs::MemoryArray(_) => {}
366 }
367 }
368 }
369
370 fn fold_brillig_outputs(&mut self, outputs: &[BrilligOutputs]) {
371 for output in outputs {
372 match output {
373 BrilligOutputs::Simple(witness) => {
374 self.fold(*witness);
375 }
376 BrilligOutputs::Array(witnesses) => self.fold_many(witnesses.iter()),
377 }
378 }
379 }
380
381 fn fold_blackbox<F: AcirField>(&mut self, call: &BlackBoxFuncCall<F>) {
382 match call {
383 BlackBoxFuncCall::AES128Encrypt { inputs, iv, key, outputs } => {
384 self.fold_inputs(inputs.as_slice());
385 self.fold_inputs(iv.as_slice());
386 self.fold_inputs(key.as_slice());
387 self.fold_many(outputs.iter());
388 }
389 BlackBoxFuncCall::AND { lhs, rhs, output, .. } => {
390 self.fold_input(lhs);
391 self.fold_input(rhs);
392 self.fold(*output);
393 }
394 BlackBoxFuncCall::XOR { lhs, rhs, output, .. } => {
395 self.fold_input(lhs);
396 self.fold_input(rhs);
397 self.fold(*output);
398 }
399 BlackBoxFuncCall::RANGE { input, .. } => {
400 self.fold_input(input);
401 }
402 BlackBoxFuncCall::Blake2s { inputs, outputs } => {
403 self.fold_inputs(inputs.as_slice());
404 self.fold_many(outputs.iter());
405 }
406 BlackBoxFuncCall::Blake3 { inputs, outputs } => {
407 self.fold_inputs(inputs.as_slice());
408 self.fold_many(outputs.iter());
409 }
410 BlackBoxFuncCall::EcdsaSecp256k1 {
411 public_key_x,
412 public_key_y,
413 signature,
414 hashed_message,
415 output,
416 predicate,
417 } => {
418 self.fold_inputs(public_key_x.as_slice());
419 self.fold_inputs(public_key_y.as_slice());
420 self.fold_inputs(signature.as_slice());
421 self.fold_inputs(hashed_message.as_slice());
422 self.fold(*output);
423 self.fold_input(predicate);
424 }
425 BlackBoxFuncCall::EcdsaSecp256r1 {
426 public_key_x,
427 public_key_y,
428 signature,
429 hashed_message,
430 output,
431 predicate,
432 } => {
433 self.fold_inputs(public_key_x.as_slice());
434 self.fold_inputs(public_key_y.as_slice());
435 self.fold_inputs(signature.as_slice());
436 self.fold_inputs(hashed_message.as_slice());
437 self.fold(*output);
438 self.fold_input(predicate);
439 }
440 BlackBoxFuncCall::MultiScalarMul { points, scalars, predicate, outputs } => {
441 self.fold_inputs(points.as_slice());
442 self.fold_inputs(scalars.as_slice());
443 self.fold_input(predicate);
444 let (x, y, i) = outputs;
445 self.fold(*x);
446 self.fold(*y);
447 self.fold(*i);
448 }
449 BlackBoxFuncCall::EmbeddedCurveAdd { input1, input2, predicate, outputs } => {
450 self.fold_inputs(input1.as_slice());
451 self.fold_inputs(input2.as_slice());
452 self.fold_input(predicate);
453 let (x, y, i) = outputs;
454 self.fold(*x);
455 self.fold(*y);
456 self.fold(*i);
457 }
458 BlackBoxFuncCall::Keccakf1600 { inputs, outputs } => {
459 self.fold_inputs(inputs.as_slice());
460 self.fold_many(outputs.iter());
461 }
462 BlackBoxFuncCall::RecursiveAggregation {
463 verification_key,
464 proof,
465 public_inputs,
466 key_hash,
467 proof_type: _,
468 predicate,
469 } => {
470 self.fold_inputs(verification_key.as_slice());
471 self.fold_inputs(proof.as_slice());
472 self.fold_inputs(public_inputs.as_slice());
473 self.fold_input(key_hash);
474 self.fold_input(predicate);
475 }
476 BlackBoxFuncCall::Poseidon2Permutation { inputs, outputs } => {
477 self.fold_inputs(inputs.as_slice());
478 self.fold_many(outputs.iter());
479 }
480 BlackBoxFuncCall::Sha256Compression { inputs, hash_values, outputs } => {
481 self.fold_inputs(inputs.as_slice());
482 self.fold_inputs(hash_values.as_slice());
483 self.fold_many(outputs.iter());
484 }
485 }
486 }
487
488 fn fold_inputs<F: AcirField>(&mut self, inputs: &[FunctionInput<F>]) {
489 for input in inputs {
490 self.fold_input(input);
491 }
492 }
493
494 fn fold_input<F: AcirField>(&mut self, input: &FunctionInput<F>) {
495 if let FunctionInput::Witness(witness) = input {
496 self.fold(*witness);
497 }
498 }
499}
500
#[cfg(test)]
mod tests {
    use super::transform_internal;
    use crate::compiler::CircuitSimulator;
    use acir::circuit::{Circuit, brillig::BrilligFunctionId};
    use std::collections::BTreeMap;

    /// Checks that this circuit is not reported as stabilized within the default number of
    /// transformer passes.
    #[test]
    fn test_max_transformer_passes() {
        let source = r#"private parameters: [w0]
        public parameters: []
        return values: [w1, w2, w3, w4, w5, w6, w7, w8, w9, w10, w11, w12, w13, w14, w15, w16, w17, w18, w19, w20, w21, w22, w23, w24, w25, w26, w27, w28, w29, w30, w31]
        BRILLIG CALL func: 0, predicate: 1, inputs: [w0, 31, 256], outputs: [w32, w33, w34, w35, w36, w37, w38, w39, w40, w41, w42, w43, w44, w45, w46, w47, w48, w49, w50, w51, w52, w53, w54, w55, w56, w57, w58, w59, w60, w61, w62]
        BLACKBOX::RANGE input: w35, bits: 8
        BLACKBOX::RANGE input: w36, bits: 8
        BLACKBOX::RANGE input: w37, bits: 8
        BLACKBOX::RANGE input: w38, bits: 8
        BLACKBOX::RANGE input: w39, bits: 8
        BLACKBOX::RANGE input: w40, bits: 8
        BLACKBOX::RANGE input: w41, bits: 8
        BLACKBOX::RANGE input: w42, bits: 8
        BLACKBOX::RANGE input: w43, bits: 8
        BLACKBOX::RANGE input: w44, bits: 8
        BLACKBOX::RANGE input: w45, bits: 8
        BLACKBOX::RANGE input: w46, bits: 8
        BLACKBOX::RANGE input: w47, bits: 8
        BLACKBOX::RANGE input: w48, bits: 8
        BLACKBOX::RANGE input: w49, bits: 8
        BLACKBOX::RANGE input: w50, bits: 8
        BLACKBOX::RANGE input: w51, bits: 8
        BLACKBOX::RANGE input: w52, bits: 8
        BLACKBOX::RANGE input: w53, bits: 8
        BLACKBOX::RANGE input: w54, bits: 8
        BLACKBOX::RANGE input: w55, bits: 8
        BLACKBOX::RANGE input: w56, bits: 8
        BLACKBOX::RANGE input: w57, bits: 8
        BLACKBOX::RANGE input: w58, bits: 8
        BLACKBOX::RANGE input: w59, bits: 8
        BLACKBOX::RANGE input: w60, bits: 8
        BLACKBOX::RANGE input: w61, bits: 8
        BLACKBOX::RANGE input: w62, bits: 8
        ASSERT w32 = w0 - 256*w33 - 65536*w34 - 16777216*w35 - 4294967296*w36 - 1099511627776*w37 - 281474976710656*w38 - 72057594037927936*w39 - 18446744073709551616*w40 - 4722366482869645213696*w41 - 1208925819614629174706176*w42 - 309485009821345068724781056*w43 - 79228162514264337593543950336*w44 - 20282409603651670423947251286016*w45 - 5192296858534827628530496329220096*w46 - 1329227995784915872903807060280344576*w47 - 340282366920938463463374607431768211456*w48 - 87112285931760246646623899502532662132736*w49 - 22300745198530623141535718272648361505980416*w50 - 5708990770823839524233143877797980545530986496*w51 - 1461501637330902918203684832716283019655932542976*w52 - 374144419156711147060143317175368453031918731001856*w53 - 95780971304118053647396689196894323976171195136475136*w54 - 24519928653854221733733552434404946937899825954937634816*w55 - 6277101735386680763835789423207666416102355444464034512896*w56 - 1606938044258990275541962092341162602522202993782792835301376*w57 - 411376139330301510538742295639337626245683966408394965837152256*w58 - 105312291668557186697918027683670432318895095400549111254310977536*w59 - 26959946667150639794667015087019630673637144422540572481103610249216*w60 - 6901746346790563787434755862277025452451108972170386555162524223799296*w61 - 1766847064778384329583297500742918515827483896875618958121606201292619776*w62
        ASSERT w32 = 60
        ASSERT w33 = 33
        ASSERT w34 = 31
        ASSERT w0 = 16777216*w35 + 4294967296*w36 + 1099511627776*w37 + 281474976710656*w38 + 72057594037927936*w39 + 18446744073709551616*w40 + 4722366482869645213696*w41 + 1208925819614629174706176*w42 + 309485009821345068724781056*w43 + 79228162514264337593543950336*w44 + 20282409603651670423947251286016*w45 + 5192296858534827628530496329220096*w46 + 1329227995784915872903807060280344576*w47 + 340282366920938463463374607431768211456*w48 + 87112285931760246646623899502532662132736*w49 + 22300745198530623141535718272648361505980416*w50 + 5708990770823839524233143877797980545530986496*w51 + 1461501637330902918203684832716283019655932542976*w52 + 374144419156711147060143317175368453031918731001856*w53 + 95780971304118053647396689196894323976171195136475136*w54 + 24519928653854221733733552434404946937899825954937634816*w55 + 6277101735386680763835789423207666416102355444464034512896*w56 + 1606938044258990275541962092341162602522202993782792835301376*w57 + 411376139330301510538742295639337626245683966408394965837152256*w58 + 105312291668557186697918027683670432318895095400549111254310977536*w59 + 26959946667150639794667015087019630673637144422540572481103610249216*w60 + 6901746346790563787434755862277025452451108972170386555162524223799296*w61 + 1766847064778384329583297500742918515827483896875618958121606201292619776*w62 + 2040124
        ASSERT w62 = w1
        ASSERT w61 = w2
        ASSERT w60 = w3
        ASSERT w59 = w4
        ASSERT w58 = w5
        ASSERT w57 = w6
        ASSERT w56 = w7
        ASSERT w55 = w8
        ASSERT w54 = w9
        ASSERT w53 = w10
        ASSERT w52 = w11
        ASSERT w51 = w12
        ASSERT w50 = w13
        ASSERT w49 = w14
        ASSERT w48 = w15
        ASSERT w47 = w16
        ASSERT w46 = w17
        ASSERT w45 = w18
        ASSERT w44 = w19
        ASSERT w43 = w20
        ASSERT w42 = w21
        ASSERT w41 = w22
        ASSERT w40 = w23
        ASSERT w39 = w24
        ASSERT w38 = w25
        ASSERT w37 = w26
        ASSERT w36 = w27
        ASSERT w35 = w28
        ASSERT w29 = 31
        ASSERT w30 = 33
        ASSERT w31 = 60
        "#;

        let circuit = Circuit::from_str(source).unwrap();
        assert!(CircuitSimulator::check_circuit(&circuit).is_none());

        // Positions 0..=64, with position 29 appearing three times in a row.
        let positions: Vec<usize> = (0..=29).chain([29, 29]).chain(30..=64).collect();
        let side_effects = BTreeMap::from([(BrilligFunctionId(0), false)]);

        let (_, _, stabilized) = transform_internal(circuit, positions, &side_effects, None);
        assert!(!stabilized);
    }
}