// acvm/compiler/transformers/mod.rs

/// This module applies backend specific transformation to a [`Circuit`].
///
/// ## CSAT: transforms AssertZero opcodes into AssertZero opcodes having the required width.
///
/// For instance, if the width is 4, the AssertZero opcode x1 + x2 + x3 + x4 + x5 - y = 0 will be transformed using 2 intermediate variables (z1,z2):
/// x1 + x2 + x3 = z1
/// x4 + x5 = z2
/// z1 + z2 - y = 0
/// If x1, ..., x5 are inputs to the program, they are tagged as 'solvable' and would be used to compute the value of y.
/// If we instead generated the intermediate variable x4 + x5 - y = z3, we would get an unsolvable circuit, because that AssertZero opcode would have two unknown values: y and z3.
/// So the CSAT transformation keeps track of which witnesses would be solved by each opcode, in order to only generate solvable intermediate variables.
///
/// ## eliminate intermediate variables
/// The 'eliminate intermediate variables' pass will remove any intermediate variables (for instance created by the previous transformation)
/// that are used in exactly two AssertZero opcodes.
/// This results in arithmetic opcodes having linear combinations of potentially large width.
/// For instance, if the intermediate variable z1 is only used in y:
/// z1 = x1 + x2 + x3
/// y = z1 + x4
/// We remove it, undoing the work done during the CSAT transformation: y = x1 + x2 + x3 + x4
///
/// We do this because the backend is expected to handle linear combinations of 'unbounded width' in a more efficient way
/// than the 'CSAT transformation'.
/// However, it is worth computing intermediate variables if they are used in more than one other opcode.
///
/// ## redundant_range
/// The 'range optimization' pass, from the optimizers module, removes any range opcodes that the previous passes made redundant.
use std::collections::BTreeMap;

use acir::{
    AcirField,
    circuit::{
        Circuit, ExpressionWidth, Opcode,
        brillig::{BrilligFunctionId, BrilligInputs, BrilligOutputs},
        opcodes::{BlackBoxFuncCall, FunctionInput, MemOp},
    },
    native_types::{Expression, Witness},
};
use indexmap::IndexMap;

mod csat;

pub(crate) use csat::CSatTransformer;
pub use csat::MIN_EXPRESSION_WIDTH;
use std::hash::BuildHasher;
use tracing::info;

use super::{
    AcirTransformationMap,
    optimizers::{MergeExpressionsOptimizer, RangeOptimizer},
    transform_assert_messages,
};

/// Upper bound on the number of `transform_internal_once` rounds run by `transform_internal`.
///
/// We need multiple passes to stabilize the output; the loop in `transform_internal` stops
/// early as soon as a round leaves the hash of the opcode list unchanged.
/// The value was determined by running tests.
const MAX_TRANSFORMER_PASSES: usize = 3;

/// Applies backend specific optimizations to a [`Circuit`].
///
/// Every opcode of the incoming circuit starts out mapped to its own index. That mapping is
/// threaded through all transformation passes and finally turned into the returned
/// [`AcirTransformationMap`], which is also used to run the circuit's assert messages through
/// `transform_assert_messages`.
pub fn transform<F: AcirField>(
    acir: Circuit<F>,
    expression_width: ExpressionWidth,
    brillig_side_effects: &BTreeMap<BrilligFunctionId, bool>,
) -> (Circuit<F>, AcirTransformationMap) {
    // Identity mapping: opcode `i` of the input circuit originates from position `i`.
    // Each pass applies its deletions/insertions to this vector alongside the opcodes.
    let initial_positions: Vec<usize> = (0..acir.opcodes.len()).collect();

    let (mut transformed, final_positions) =
        transform_internal(acir, expression_width, initial_positions, brillig_side_effects);

    let transformation_map = AcirTransformationMap::new(&final_positions);
    transformed.assert_messages =
        transform_assert_messages(transformed.assert_messages, &transformation_map);

    (transformed, transformation_map)
}

/// Applies backend specific optimizations to a [`Circuit`], repeating them until the opcodes settle.
///
/// `acir_opcode_positions` is supplied by the caller so these transformations can be chained
/// directly after the optimization passes while still tracking original opcode positions.
///
/// At most [`MAX_TRANSFORMER_PASSES`] rounds of `transform_internal_once` are performed; the loop
/// exits as soon as a round leaves the hash of the opcode list unchanged. A circuit consisting
/// of a single `BrilligCall` is returned untouched.
#[tracing::instrument(level = "trace", name = "transform_acir", skip(acir, acir_opcode_positions))]
pub(super) fn transform_internal<F: AcirField>(
    mut acir: Circuit<F>,
    expression_width: ExpressionWidth,
    mut acir_opcode_positions: Vec<usize>,
    brillig_side_effects: &BTreeMap<BrilligFunctionId, bool>,
) -> (Circuit<F>, Vec<usize>) {
    // A lone Brillig call means there are no constraints to reshape.
    if matches!(acir.opcodes.as_slice(), [Opcode::BrilligCall { .. }]) {
        info!("Program is fully unconstrained, skipping transformation pass");
        return (acir, acir_opcode_positions);
    }

    let mut previous_hash = rustc_hash::FxBuildHasher.hash_one(&acir.opcodes);

    // Looping here is enough for most programs, though some only stabilize if the
    // backend agnostic optimizations are repeated as well.
    for _ in 0..MAX_TRANSFORMER_PASSES {
        info!("Number of opcodes {}", acir.opcodes.len());
        (acir, acir_opcode_positions) = transform_internal_once(
            acir,
            expression_width,
            acir_opcode_positions,
            brillig_side_effects,
        );

        let current_hash = rustc_hash::FxBuildHasher.hash_one(&acir.opcodes);
        if current_hash == previous_hash {
            break;
        }
        previous_hash = current_hash;
    }

    // Eliminating intermediate variables can leave `current_witness_index` above the largest
    // witness still referenced; lowering it avoids gaps that would look like newly added
    // variables if the transformation were run again.
    acir.current_witness_index = max_witness(&acir).witness_index();
    (acir, acir_opcode_positions)
}

/// Performs a single round of backend specific transformations on `acir`.
///
/// `acir_opcode_positions` is supplied by the caller so the transformations can be applied
/// directly after optimizations; the returned vector maps every output opcode back to the
/// original position it was derived from.
///
/// With an unbounded expression width the circuit is returned as is. Otherwise the round:
/// 1. runs the CSAT transformation, introducing intermediate variables so that every
///    `AssertZero` opcode fits the requested width,
/// 2. calls `eliminate_intermediate_variable()` to fold intermediate variables back in when
///    they are used in exactly two arithmetic opcodes,
/// 3. finishes with `replace_redundant_ranges()` to drop range checks made redundant by step 2.
#[tracing::instrument(
    level = "trace",
    name = "transform_acir_once",
    skip(acir, acir_opcode_positions)
)]
fn transform_internal_once<F: AcirField>(
    mut acir: Circuit<F>,
    expression_width: ExpressionWidth,
    acir_opcode_positions: Vec<usize>,
    brillig_side_effects: &BTreeMap<BrilligFunctionId, bool>,
) -> (Circuit<F>, Vec<usize>) {
    // Nothing needs reshaping when expressions may have any width.
    let ExpressionWidth::Bounded { width } = expression_width else {
        return (acir, acir_opcode_positions);
    };

    // The witnesses returned by `circuit_arguments()` are solvable from the start.
    let mut csat = CSatTransformer::new(width);
    for argument in acir.circuit_arguments() {
        csat.mark_solvable(argument);
    }

    // 1. CSAT transformation.
    // Walk the opcodes in order, tracking which witnesses become solvable so that narrowing an
    // `AssertZero` opcode only ever introduces solvable intermediate variables.
    let mut positions: Vec<usize> = Vec::with_capacity(acir_opcode_positions.len());
    let mut opcodes = Vec::new();
    let mut next_witness_index = acir.current_witness_index + 1;
    // Normalized expression -> (norm, witness of the intermediate variable standing for it).
    // The norm is the first non-zero coefficient, taken from the linear terms, or from the
    // quadratic terms when there are no linear ones.
    let mut intermediate_variables: IndexMap<Expression<F>, (F, Witness)> = IndexMap::new();

    for (index, opcode) in acir.opcodes.into_iter().enumerate() {
        let source_position = acir_opcode_positions[index];
        match opcode {
            Opcode::AssertZero(expr) => {
                let known_before = intermediate_variables.len();
                let narrowed =
                    csat.transform(expr, &mut intermediate_variables, &mut next_witness_index);

                // Every intermediate variable created for this opcode gets its own defining
                // constraint, emitted ahead of the narrowed opcode that relies on it.
                let definitions = intermediate_variables.iter().skip(known_before).map(
                    |(normalized, (norm, witness))| {
                        // De-normalize, then constrain the expression to equal `witness`.
                        let mut definition = normalized * *norm;
                        definition.linear_combinations.push((-F::one(), *witness));
                        definition.sort();
                        definition
                    },
                );
                for expr in definitions.chain(std::iter::once(narrowed)) {
                    positions.push(source_position);
                    opcodes.push(Opcode::AssertZero(expr));
                }
            }
            other => {
                // Record the witnesses this opcode solves so later expressions may rely on them.
                match &other {
                    Opcode::BlackBoxFuncCall(call) => {
                        for witness in call.get_outputs_vec() {
                            csat.mark_solvable(witness);
                        }
                    }
                    Opcode::MemoryOp { op, .. } => {
                        for (_, lhs, rhs) in &op.value.mul_terms {
                            csat.mark_solvable(*lhs);
                            csat.mark_solvable(*rhs);
                        }
                        for (_, witness) in &op.value.linear_combinations {
                            csat.mark_solvable(*witness);
                        }
                    }
                    Opcode::BrilligCall { outputs, .. } => {
                        for output in outputs {
                            match output {
                                BrilligOutputs::Simple(witness) => csat.mark_solvable(*witness),
                                BrilligOutputs::Array(witnesses) => {
                                    for witness in witnesses {
                                        csat.mark_solvable(*witness);
                                    }
                                }
                            }
                        }
                    }
                    Opcode::Call { outputs, .. } => {
                        // `Call` does not write values to the `WitnessMap`; the callee, being a
                        // separate ACIR function, has its own `WitnessMap`.
                        for witness in outputs {
                            csat.mark_solvable(*witness);
                        }
                    }
                    // `MemoryInit` does not write values to the `WitnessMap`;
                    // `AssertZero` was handled by the outer match.
                    Opcode::MemoryInit { .. } | Opcode::AssertZero(_) => {}
                }
                positions.push(source_position);
                opcodes.push(other);
            }
        }
    }

    acir = Circuit {
        current_witness_index: next_witness_index - 1,
        opcodes,
        // The transformer does not add new public inputs
        ..acir
    };

    // 2. Eliminate intermediate variables, when they are used in exactly two arithmetic opcodes.
    // n.b. `current_witness_index` is not lowered here, so it may exceed the largest witness
    // still referenced after this step.
    let mut merger = MergeExpressionsOptimizer::new();
    let (opcodes, positions) = merger.eliminate_intermediate_variable(&acir, positions);
    acir = Circuit {
        opcodes,
        // The optimizer does not add new public inputs
        ..acir
    };

    // 3. Merging can combine witnesses that each carried a range opcode, so clear up the
    // range constraints that became redundant.
    RangeOptimizer::new(acir, brillig_side_effects).replace_redundant_ranges(positions)
}

/// Return the highest [`Witness`] visited by `WitnessFolder::fold_circuit` for `circuit`.
///
/// Starts from `Witness::default()`, which is therefore the result when no witness is visited.
fn max_witness<F: AcirField>(circuit: &Circuit<F>) -> Witness {
    let mut folder =
        WitnessFolder::new(Witness::default(), |highest: &mut Witness, candidate: Witness| {
            if candidate > *highest {
                *highest = candidate;
            }
        });
    folder.fold_circuit(circuit);
    folder.into_state()
}

/// Fold all witnesses in a circuit.
///
/// Each visited witness is passed to `accumulate` together with a mutable reference to
/// `state`, which collects the result.
struct WitnessFolder<S, A> {
    /// Accumulated result, updated in place for every folded witness.
    state: S,
    /// Callback invoked as `accumulate(&mut state, witness)` for each folded witness.
    accumulate: A,
}

impl<S, A> WitnessFolder<S, A>
where
    A: Fn(&mut S, Witness),
{
    /// Create the folder with some initial state and an accumulator function.
    fn new(init: S, accumulate: A) -> Self {
        Self { state: init, accumulate }
    }

    /// Take the accumulated state.
    fn into_state(self) -> S {
        self.state
    }

    /// Add all witnesses from the circuit.
    fn fold_circuit<F: AcirField>(&mut self, circuit: &Circuit<F>) {
        self.fold_many(circuit.private_parameters.iter());
        self.fold_many(circuit.public_parameters.0.iter());
        self.fold_many(circuit.return_values.0.iter());
        for opcode in &circuit.opcodes {
            self.fold_opcode(opcode);
        }
    }

    /// Fold a witness into the state.
    fn fold(&mut self, witness: Witness) {
        (self.accumulate)(&mut self.state, witness);
    }

    /// Fold many witnesses into the state.
    fn fold_many<'w, I: Iterator<Item = &'w Witness>>(&mut self, witnesses: I) {
        for w in witnesses {
            self.fold(*w);
        }
    }

    /// Add witnesses from the opcode.
    fn fold_opcode<F: AcirField>(&mut self, opcode: &Opcode<F>) {
        match opcode {
            Opcode::AssertZero(expr) => {
                self.fold_expr(expr);
            }
            Opcode::BlackBoxFuncCall(call) => self.fold_blackbox(call),
            Opcode::MemoryOp { block_id: _, op } => {
                let MemOp { operation, index, value } = op;
                self.fold_expr(operation);
                self.fold_expr(index);
                self.fold_expr(value);
            }
            Opcode::MemoryInit { block_id: _, init, block_type: _ } => {
                for w in init {
                    self.fold(*w);
                }
            }
            // We keep the display for a BrilligCall and circuit Call separate as they
            // are distinct in their functionality and we should maintain this separation for debugging.
            Opcode::BrilligCall { id: _, inputs, outputs, predicate } => {
                if let Some(pred) = predicate {
                    self.fold_expr(pred);
                }
                self.fold_brillig_inputs(inputs);
                self.fold_brillig_outputs(outputs);
            }
            Opcode::Call { id: _, inputs, outputs, predicate } => {
                if let Some(pred) = predicate {
                    self.fold_expr(pred);
                }
                self.fold_many(inputs.iter());
                self.fold_many(outputs.iter());
            }
        }
    }

    fn fold_expr<F: AcirField>(&mut self, expr: &Expression<F>) {
        for i in &expr.mul_terms {
            self.fold(i.1);
            self.fold(i.2);
        }
        for i in &expr.linear_combinations {
            self.fold(i.1);
        }
    }

    fn fold_brillig_inputs<F: AcirField>(&mut self, inputs: &[BrilligInputs<F>]) {
        for input in inputs {
            match input {
                BrilligInputs::Single(expr) => {
                    self.fold_expr(expr);
                }
                BrilligInputs::Array(exprs) => {
                    for expr in exprs {
                        self.fold_expr(expr);
                    }
                }
                BrilligInputs::MemoryArray(_) => {}
            }
        }
    }

    fn fold_brillig_outputs(&mut self, outputs: &[BrilligOutputs]) {
        for output in outputs {
            match output {
                BrilligOutputs::Simple(w) => {
                    self.fold(*w);
                }
                BrilligOutputs::Array(ws) => self.fold_many(ws.iter()),
            }
        }
    }

    fn fold_blackbox<F: AcirField>(&mut self, call: &BlackBoxFuncCall<F>) {
        match call {
            BlackBoxFuncCall::AES128Encrypt { inputs, iv, key, outputs } => {
                self.fold_inputs(inputs.as_slice());
                self.fold_inputs(iv.as_slice());
                self.fold_inputs(key.as_slice());
                self.fold_many(outputs.iter());
            }
            BlackBoxFuncCall::AND { lhs, rhs, output, .. } => {
                self.fold_input(lhs);
                self.fold_input(rhs);
                self.fold(*output);
            }
            BlackBoxFuncCall::XOR { lhs, rhs, output, .. } => {
                self.fold_input(lhs);
                self.fold_input(rhs);
                self.fold(*output);
            }
            BlackBoxFuncCall::RANGE { input, .. } => {
                self.fold_input(input);
            }
            BlackBoxFuncCall::Blake2s { inputs, outputs } => {
                self.fold_inputs(inputs.as_slice());
                self.fold_many(outputs.iter());
            }
            BlackBoxFuncCall::Blake3 { inputs, outputs } => {
                self.fold_inputs(inputs.as_slice());
                self.fold_many(outputs.iter());
            }
            BlackBoxFuncCall::EcdsaSecp256k1 {
                public_key_x,
                public_key_y,
                signature,
                hashed_message,
                output,
                predicate,
            } => {
                self.fold_inputs(public_key_x.as_slice());
                self.fold_inputs(public_key_y.as_slice());
                self.fold_inputs(signature.as_slice());
                self.fold_inputs(hashed_message.as_slice());
                self.fold(*output);
                self.fold_input(predicate);
            }
            BlackBoxFuncCall::EcdsaSecp256r1 {
                public_key_x,
                public_key_y,
                signature,
                hashed_message,
                output,
                predicate,
            } => {
                self.fold_inputs(public_key_x.as_slice());
                self.fold_inputs(public_key_y.as_slice());
                self.fold_inputs(signature.as_slice());
                self.fold_inputs(hashed_message.as_slice());
                self.fold(*output);
                self.fold_input(predicate);
            }
            BlackBoxFuncCall::MultiScalarMul { points, scalars, predicate, outputs } => {
                self.fold_inputs(points.as_slice());
                self.fold_inputs(scalars.as_slice());
                self.fold_input(predicate);
                let (x, y, i) = outputs;
                self.fold(*x);
                self.fold(*y);
                self.fold(*i);
            }
            BlackBoxFuncCall::EmbeddedCurveAdd { input1, input2, predicate, outputs } => {
                self.fold_inputs(input1.as_slice());
                self.fold_inputs(input2.as_slice());
                self.fold_input(predicate);
                let (x, y, i) = outputs;
                self.fold(*x);
                self.fold(*y);
                self.fold(*i);
            }
            BlackBoxFuncCall::Keccakf1600 { inputs, outputs } => {
                self.fold_inputs(inputs.as_slice());
                self.fold_many(outputs.iter());
            }
            BlackBoxFuncCall::RecursiveAggregation {
                verification_key,
                proof,
                public_inputs,
                key_hash,
                proof_type: _,
                predicate,
            } => {
                self.fold_inputs(verification_key.as_slice());
                self.fold_inputs(proof.as_slice());
                self.fold_inputs(public_inputs.as_slice());
                self.fold_input(key_hash);
                self.fold_input(predicate);
            }
            BlackBoxFuncCall::Poseidon2Permutation { inputs, outputs } => {
                self.fold_inputs(inputs.as_slice());
                self.fold_many(outputs.iter());
            }
            BlackBoxFuncCall::Sha256Compression { inputs, hash_values, outputs } => {
                self.fold_inputs(inputs.as_slice());
                self.fold_inputs(hash_values.as_slice());
                self.fold_many(outputs.iter());
            }
        }
    }

    fn fold_inputs<F: AcirField>(&mut self, inputs: &[FunctionInput<F>]) {
        for input in inputs {
            self.fold_input(input);
        }
    }

    fn fold_input<F: AcirField>(&mut self, input: &FunctionInput<F>) {
        if let FunctionInput::Witness(witness) = input {
            self.fold(*witness);
        }
    }
}