acvm/compiler/optimizers/common_subexpression/csat.rs

1use std::{cmp::Ordering, collections::HashSet};
2
3use acir::{
4    AcirField,
5    native_types::{Expression, Witness},
6};
7use indexmap::IndexMap;
8
9/// Minimum width accepted by the `CSatTransformer`.
10pub(crate) const MIN_EXPRESSION_WIDTH: usize = 3;
11
12/// A transformer which processes any [`Expression`]s to break them up such that they
13/// fit within the backend's width.
14///
15/// This is done by creating intermediate variables to hold partial calculations and then combining them
16/// to calculate the original expression.
17///
18/// Pre-Condition:
19/// - General Optimizer must run before this pass
20pub(crate) struct CSatTransformer {
21    width: usize,
22    /// Track the witness that can be solved
23    solvable_witness: HashSet<Witness>,
24}
25
26impl CSatTransformer {
27    /// Create an optimizer with a given width.
28    ///
29    /// Panics if `width` is less than `MIN_EXPRESSION_WIDTH`.
30    pub(crate) fn new(width: usize) -> CSatTransformer {
31        assert!(width >= MIN_EXPRESSION_WIDTH, "width has to be at least {MIN_EXPRESSION_WIDTH}");
32
33        CSatTransformer { width, solvable_witness: HashSet::new() }
34    }
35
36    /// Check if the equation 'expression=0' can be solved, and if yes, add the solved witness to set of solvable witness
37    fn try_solve<F>(&mut self, opcode: &Expression<F>) {
38        let mut unresolved = Vec::new();
39        for (_, w1, w2) in &opcode.mul_terms {
40            if !self.solvable_witness.contains(w1) {
41                unresolved.push(w1);
42                if !self.solvable_witness.contains(w2) {
43                    return;
44                }
45            }
46            if !self.solvable_witness.contains(w2) {
47                unresolved.push(w2);
48                if !self.solvable_witness.contains(w1) {
49                    return;
50                }
51            }
52        }
53        for (_, w) in &opcode.linear_combinations {
54            if !self.solvable_witness.contains(w) {
55                unresolved.push(w);
56            }
57        }
58        if unresolved.len() == 1 {
59            self.mark_solvable(*unresolved[0]);
60        }
61    }
62
63    /// Adds the witness to set of solvable witness
64    pub(crate) fn mark_solvable(&mut self, witness: Witness) {
65        self.solvable_witness.insert(witness);
66    }
67
68    /// Transform the input arithmetic expression into a new one having the correct 'width'
69    /// by creating intermediate variables as needed.
70    /// Having the correct width means:
71    /// - it has at most one multiplicative term
72    /// - it uses at most 'width-1' witness linear combination terms, to account for the new intermediate variable
73    pub(crate) fn transform<F: AcirField>(
74        &mut self,
75        opcode: Expression<F>,
76        intermediate_variables: &mut IndexMap<Expression<F>, (F, Witness)>,
77        num_witness: &mut u32,
78    ) -> Expression<F> {
79        // Here we create intermediate variables and constrain them to be equal to any subset of the polynomial that can be represented as a full opcode
80        let opcode =
81            self.full_opcode_scan_optimization(opcode, intermediate_variables, num_witness);
82        // The last optimization to do is to create intermediate variables in order to flatten the fan-in and the amount of mul terms
83        // If a opcode has more than one mul term. We may need an intermediate variable for each one. Since not every variable will need to link to
84        // the mul term, we could possibly do it that way.
85        // We wil call this a partial opcode scan optimization which will result in the opcodes being able to fit into the correct width
86        let mut opcode =
87            self.partial_opcode_scan_optimization(opcode, intermediate_variables, num_witness);
88        opcode.sort();
89        self.try_solve(&opcode);
90        opcode
91    }
92
93    // This optimization will search for combinations of terms which can be represented in a single assert-zero opcode
94    // Case 1 : qM * wL * wR + qL * wL + qR * wR + qO * wO + qC
95    // This polynomial does not require any further optimizations, it can be safely represented in one opcode
96    // ie a polynomial with 1 mul(bi-variate) term and 3 (univariate) terms where 2 of those terms match the bivariate term
97    // wL and wR, we can represent it in one opcode
98    // GENERALIZED for WIDTH: instead of the number 3, we use `WIDTH`
99    //
100    //
101    // Case 2: qM * wL * wR + qL * wL + qR * wR + qO * wO + qC + qM2 * wL2 * wR2 + qL * wL2 + qR * wR2 + qO * wO2 + qC2
102    // This polynomial cannot be represented using one assert-zero opcode.
103    //
104    // This algorithm will first extract the first full opcode(if possible):
105    // t = qM * wL * wR + qL * wL + qR * wR + qO * wO + qC
106    //
107    // The polynomial now looks like so t + qM2 * wL2 * wR2 + qL * wL2 + qR * wR2 + qO * wO2 + qC2
108    // This polynomial cannot be represented using one assert-zero opcode.
109    //
110    // This algorithm will then extract the second full opcode(if possible):
111    // t2 = qM2 * wL2 * wR2 + qL * wL2 + qR * wR2 + qO * wO2 + qC2
112    //
113    // The polynomial now looks like so t + t2
114    // We can no longer extract another full opcode, hence the algorithm terminates. Creating two intermediate variables t and t2.
115    // This stage of preprocessing does not guarantee that all polynomials can fit into a opcode. It only guarantees that all full opcodes have been extracted from each polynomial
116    fn full_opcode_scan_optimization<F: AcirField>(
117        &mut self,
118        mut opcode: Expression<F>,
119        intermediate_variables: &mut IndexMap<Expression<F>, (F, Witness)>,
120        num_witness: &mut u32,
121    ) -> Expression<F> {
122        // We pass around this intermediate variable IndexMap, so that we do not create intermediate variables that we have created before
123        // One instance where this might happen is t1 = wL * wR and t2 = wR * wL
124
125        // First check that this is not a simple opcode which does not need optimization
126        //
127        // If the opcode only has one mul term, then this algorithm cannot optimize it any further
128        // Either it can be represented in a single arithmetic equation or its fan-in is too large and we need intermediate variables for those
129        // Large-fan-in optimization is not this algorithm's purpose.
130        // If the opcode has 0 mul terms, then it is an add opcode and similarly it can either fit into a single assert-zero opcode or it has a large fan-in
131        if opcode.mul_terms.len() <= 1 {
132            return opcode;
133        }
134
135        // We now know that this opcode has multiple mul terms and can possibly be simplified into multiple full opcodes
136        // We need to create a (w_l, w_r) IndexMap and then check the simplified fan-in to verify if we have terms both with w_l and w_r
137        // In general, we can then push more terms into the opcode until we are at width-1 then the last variable will be the intermediate variable
138        //
139
140        // This will be our new opcode which will be equal to `self` except we will have intermediate variables that will be constrained to any
141        // subset of the terms that can be represented as full opcodes
142        let mut new_opcode = Expression::default();
143        let mut remaining_mul_terms = Vec::with_capacity(opcode.mul_terms.len());
144        for (scale, w_l, w_r) in opcode.mul_terms {
145            // We want to layout solvable intermediate variables, if we cannot solve one of the witnesses
146            // that means the intermediate opcode will not be immediately solvable
147            if !self.solvable_witness.contains(&w_l) || !self.solvable_witness.contains(&w_r) {
148                remaining_mul_terms.push((scale, w_l, w_r));
149                continue;
150            }
151
152            // Check if this (scale, w_l, w_r) triple is present in the simplified fan-in
153            // We are assuming that the fan-in/fan-out has been simplified.
154            // Note this function is not public, and can only be called within the optimize method, so this guarantee will always hold
155            let index_wl =
156                opcode.linear_combinations.iter().position(|(_scale, witness)| *witness == w_l);
157            let index_wr =
158                opcode.linear_combinations.iter().position(|(_scale, witness)| *witness == w_r);
159
160            match (index_wl, index_wr) {
161                (None, _) | (_, None) => {
162                    // This means that the polynomial does not contain both terms
163                    // Just push the Qm term as it cannot form a full opcode
164                    new_opcode.mul_terms.push((scale, w_l, w_r));
165                }
166                (Some(x), Some(y)) => {
167                    // This means that we can form a full opcode with this Qm term
168
169                    // First fetch the left and right wires which match the mul term
170                    let left_wire_term = opcode.linear_combinations[x];
171                    let right_wire_term = opcode.linear_combinations[y];
172
173                    // Lets create an intermediate opcode to store this full opcode
174                    //
175                    let mut intermediate_opcode = Expression::default();
176                    intermediate_opcode.mul_terms.push((scale, w_l, w_r));
177
178                    // Add the left and right wires
179                    intermediate_opcode.linear_combinations.push(left_wire_term);
180                    intermediate_opcode.linear_combinations.push(right_wire_term);
181                    // Remove the left and right wires so we do not re-add them
182                    match x.cmp(&y) {
183                        Ordering::Greater => {
184                            opcode.linear_combinations.remove(x);
185                            opcode.linear_combinations.remove(y);
186                        }
187                        Ordering::Less => {
188                            opcode.linear_combinations.remove(y);
189                            opcode.linear_combinations.remove(x);
190                        }
191                        Ordering::Equal => {
192                            opcode.linear_combinations.remove(x);
193                            intermediate_opcode.linear_combinations.pop();
194                        }
195                    }
196
197                    let used_space = intermediate_opcode.linear_combinations.len();
198                    assert!(used_space < self.width);
199
200                    // Now we have used up "used_space" spaces in our assert-zero opcode. The width now dictates how many more we can add
201                    let mut remaining_space = self.width - used_space - 1; // We minus 1 because we need an extra space to contain the intermediate variable
202                    // Keep adding terms until we have no more left, or we reach the width
203                    let mut remaining_linear_terms =
204                        Vec::with_capacity(opcode.linear_combinations.len());
205                    while remaining_space > 0 {
206                        if let Some(wire_term) = opcode.linear_combinations.pop() {
207                            // Add this element into the new opcode
208                            if self.solvable_witness.contains(&wire_term.1) {
209                                intermediate_opcode.linear_combinations.push(wire_term);
210                                remaining_space -= 1;
211                            } else {
212                                remaining_linear_terms.push(wire_term);
213                            }
214                        } else {
215                            // No more usable elements left in the old opcode
216                            break;
217                        }
218                    }
219                    opcode.linear_combinations.extend(remaining_linear_terms);
220
221                    // Constrain this intermediate_opcode to be equal to the temp variable by adding it into the IndexMap
222                    // We need a unique name for our intermediate variable
223                    // TODO(https://github.com/noir-lang/noir/issues/10192): Another optimization, which could be applied in another algorithm
224                    // If two opcodes have a large fan-in/out and they share a few common terms, then we should create intermediate variables for them
225                    // Do some sort of subset matching algorithm for this on the terms of the polynomial
226                    let intermediate_var = Self::get_or_create_intermediate_var(
227                        intermediate_variables,
228                        intermediate_opcode,
229                        num_witness,
230                    );
231
232                    // Add intermediate variable to the new opcode instead of the full opcode
233                    self.mark_solvable(intermediate_var.1);
234                    new_opcode.linear_combinations.push(intermediate_var);
235                }
236            }
237        }
238
239        // Add the rest of the elements back into the new_opcode
240        new_opcode.mul_terms.extend(remaining_mul_terms);
241        new_opcode.linear_combinations.extend(opcode.linear_combinations);
242        new_opcode.q_c = opcode.q_c;
243        new_opcode.sort();
244        new_opcode
245    }
246
247    /// Normalize an expression by dividing it by its first coefficient
248    /// The first coefficient here means coefficient of the first linear term, or of the first quadratic term if no linear terms exist.
249    /// This function panics if the input expression is constant or if the first coefficient's inverse is F::zero()
250    fn normalize<F: AcirField>(mut expr: Expression<F>) -> (F, Expression<F>) {
251        expr.sort();
252        let a = if !expr.linear_combinations.is_empty() {
253            expr.linear_combinations[0].0
254        } else {
255            expr.mul_terms[0].0
256        };
257        let a_inverse = a.inverse();
258        assert!(a_inverse != F::zero(), "normalize: the first coefficient is non-invertible");
259        (a, &expr * a_inverse)
260    }
261
262    /// Get or generate a scaled intermediate witness which is equal to the provided expression
263    /// The sets of previously generated witness and their (normalized) expression is cached in the intermediate_variables map
264    /// If there is no cache hit, we generate a new witness (and add the expression to the cache)
265    /// else, we return the cached witness along with the scaling factor so it is equal to the provided expression
266    fn get_or_create_intermediate_var<F: AcirField>(
267        intermediate_variables: &mut IndexMap<Expression<F>, (F, Witness)>,
268        expr: Expression<F>,
269        num_witness: &mut u32,
270    ) -> (F, Witness) {
271        let (k, normalized_expr) = Self::normalize(expr);
272
273        if intermediate_variables.contains_key(&normalized_expr) {
274            let (l, iv) = intermediate_variables[&normalized_expr];
275            assert!(
276                l != F::zero(),
277                "get_or_create_intermediate_var: attempting to divide l by F::zero()"
278            );
279            (k / l, iv)
280        } else {
281            let inter_var = Witness(*num_witness);
282            *num_witness += 1;
283            // Add intermediate opcode and variable to map
284            intermediate_variables.insert(normalized_expr, (k, inter_var));
285            (F::one(), inter_var)
286        }
287    }
288
289    // A partial opcode scan optimization aim to create intermediate variables in order to compress the polynomial
290    // So that it fits within the given width
291    // Note that this opcode follows the full opcode scan optimization.
292    // We define the partial width as equal to the full width - 2.
293    // This is because two of our variables cannot be used as they are linked to the multiplication terms
294    // Example: qM1 * wL1 * wR2 + qL1 * wL3 + qR1 * wR4+ qR2 * wR5 + qO1 * wO5 + qC
295    // One thing to note is that the multiplication wires do not match any of the fan-in/out wires. This is guaranteed as we have
296    // just completed the full opcode optimization algorithm.
297    //
298    // Actually we can optimize in two ways here: We can create an intermediate variable which is equal to the fan-in terms
299    // t = qL1 * wL3 + qR1 * wR4 -> width = 3
300    // This `t` value can only use width - 1 terms
301    // The opcode now looks like: qM1 * wL1 * wR2 + t + qR2 * wR5+ qO1 * wO5 + qC
302    // But this is still not acceptable since wR5 is not wR2, so we need another intermediate variable
303    // t2 = t + qR2 * wR5
304    //
305    // The opcode now looks like: qM1 * wL1 * wR2 + t2 + qO1 * wO5 + qC
306    // This is still not good, so we do it one more time:
307    // t3 = t2 + qO1 * wO5
308    // The opcode now looks like: qM1 * wL1 * wR2 + t3 + qC
309    //
310    // Another strategy is to create a temporary variable for the multiplier term and then we can see it as a term in the fan-in
311    //
312    // Same Example: qM1 * wL1 * wR2 + qL1 * wL3 + qR1 * wR4+ qR2 * wR5 + qO1 * wO5 + qC
313    // t = qM1 * wL1 * wR2
314    // The opcode now looks like: t + qL1 * wL3 + qR1 * wR4+ qR2 * wR5 + qO1 * wO5 + qC
315    // Still assuming width-3, we still need to use width-1 terms for the intermediate variables, however we can stop at an earlier stage because
316    // the opcode does not need the multiplier term to match with any of the fan-in terms
317    // t2 = t + qL1 * wL3
318    // The opcode now looks like: t2 + qR1 * wR4+ qR2 * wR5 + qO1 * wO5 + qC
319    // t3 = t2 + qR1 * wR4
320    // The opcode now looks like: t3 + qR2 * wR5 + qO1 * wO5 + qC
321    // This took the same amount of opcodes, but which one is better when the width increases? Compute this and maybe do both optimizations
322    // naming : partial_opcode_mul_first_opt and partial_opcode_fan_first_opt
323    // Also remember that since we did full opcode scan, there is no way we can have a non-zero mul term along with the wL and wR terms being non-zero
324    //
325    // Cases, a lot of mul terms, a lot of fan-in terms, 50/50
326    fn partial_opcode_scan_optimization<F: AcirField>(
327        &mut self,
328        mut opcode: Expression<F>,
329        intermediate_variables: &mut IndexMap<Expression<F>, (F, Witness)>,
330        num_witness: &mut u32,
331    ) -> Expression<F> {
332        // We will go for the easiest route, which is to convert all multiplications into additions using intermediate variables
333        // Then use intermediate variables again to squash the fan-in, so that it can fit into the appropriate width
334
335        // First check if this polynomial actually needs a partial opcode optimization
336        // There is the chance that it fits perfectly within the assert-zero opcode
337        if fits_in_one_identity(&opcode, self.width) {
338            return opcode;
339        }
340
341        // Create Intermediate variables for the multiplication opcodes
342        let mut remaining_mul_terms = Vec::with_capacity(opcode.mul_terms.len());
343        for (scale, w_l, w_r) in opcode.mul_terms {
344            if self.solvable_witness.contains(&w_l) && self.solvable_witness.contains(&w_r) {
345                let mut intermediate_opcode = Expression::default();
346
347                // Push mul term into the opcode
348                intermediate_opcode.mul_terms.push((scale, w_l, w_r));
349                // Get an intermediate variable which squashes the multiplication term
350                let intermediate_var = Self::get_or_create_intermediate_var(
351                    intermediate_variables,
352                    intermediate_opcode,
353                    num_witness,
354                );
355
356                // Add intermediate variable as a part of the fan-in for the original opcode
357                opcode.linear_combinations.push(intermediate_var);
358                self.mark_solvable(intermediate_var.1);
359            } else {
360                remaining_mul_terms.push((scale, w_l, w_r));
361            }
362        }
363
364        // Remove all of the mul terms as we have intermediate variables to represent them now
365        opcode.mul_terms = remaining_mul_terms;
366
367        // We now only have a polynomial with only fan-in/fan-out terms i.e. terms of the form Ax + By + Cd + ...
368        // Lets create intermediate variables if all of them cannot fit into the width
369        //
370        // If the polynomial fits perfectly within the given width, we are finished
371        if opcode.linear_combinations.len() <= self.width {
372            return opcode;
373        }
374
375        // Stores the intermediate variables that are used to
376        // reduce the fan in.
377        let mut added_vars = Vec::new();
378
379        while opcode.linear_combinations.len() > self.width {
380            // Collect as many terms up to the given width-1 and constrain them to an intermediate variable
381            let mut intermediate_opcode = Expression::default();
382
383            let mut remaining_linear_terms = Vec::with_capacity(opcode.linear_combinations.len());
384
385            for term in opcode.linear_combinations {
386                if self.solvable_witness.contains(&term.1)
387                    && intermediate_opcode.linear_combinations.len() < self.width - 1
388                {
389                    intermediate_opcode.linear_combinations.push(term);
390                } else {
391                    remaining_linear_terms.push(term);
392                }
393            }
394            opcode.linear_combinations = remaining_linear_terms;
395            let not_full = intermediate_opcode.linear_combinations.len() < self.width - 1;
396            if intermediate_opcode.linear_combinations.len() > 1 {
397                let intermediate_var = Self::get_or_create_intermediate_var(
398                    intermediate_variables,
399                    intermediate_opcode,
400                    num_witness,
401                );
402                self.mark_solvable(intermediate_var.1);
403                added_vars.push(intermediate_var);
404            } else {
405                // Put back the single term that couldn't form an intermediate variable
406                opcode.linear_combinations.extend(intermediate_opcode.linear_combinations);
407            }
408            // The intermediate opcode is not full, but the opcode still has too many terms
409            if not_full && opcode.linear_combinations.len() > self.width {
410                unreachable!("Could not reduce the expression");
411            }
412        }
413
414        // Add back the intermediate variables to
415        // keep consistency with the original equation.
416        opcode.linear_combinations.extend(added_vars);
417        self.partial_opcode_scan_optimization(opcode, intermediate_variables, num_witness)
418    }
419}
420
421/// Checks if this expression can fit into one arithmetic identity
422fn fits_in_one_identity<F: AcirField>(expr: &Expression<F>, width: usize) -> bool {
423    // A Polynomial with more than one mul term cannot fit into one opcode
424    if expr.mul_terms.len() > 1 {
425        return false;
426    }
427
428    expr.width() <= width
429}
430
#[cfg(test)]
mod tests {
    use super::*;
    use acir::FieldElement;

    #[test]
    fn simple_reduction_smoke_test() {
        let a = Witness(0);
        let b = Witness(1);
        let c = Witness(2);
        let d = Witness(3);

        // a = b + c + d;
        let opcode_a = Expression::from_str(&format!("{a} - {b} - {c} - {d}")).unwrap();

        let mut intermediate_variables: IndexMap<
            Expression<FieldElement>,
            (FieldElement, Witness),
        > = IndexMap::new();

        let mut num_witness = 4;

        let mut optimizer = CSatTransformer::new(3);
        optimizer.mark_solvable(b);
        optimizer.mark_solvable(c);
        optimizer.mark_solvable(d);
        let got_optimized_opcode_a =
            optimizer.transform(opcode_a, &mut intermediate_variables, &mut num_witness);

        // a = b + c + d => a - b - c - d = 0
        // For width 3, the terms `-b` and `-c` are moved into a new intermediate variable e,
        // and the result becomes:
        //   a - d + e = 0
        // with e constrained by:
        //   - c - b - e = 0
        let e = Witness(4);
        let expected_optimized_opcode_a =
            Expression::from_str(&format!("{a} - {d} + {e}")).unwrap();

        assert_eq!(expected_optimized_opcode_a, got_optimized_opcode_a);

        assert_eq!(intermediate_variables.len(), 1);

        // e = - c - b
        let expected_intermediate_opcode = Expression::from_str(&format!("-{c} - {b}")).unwrap();
        let (_, normalized_opcode) = CSatTransformer::normalize(expected_intermediate_opcode);
        assert!(intermediate_variables.contains_key(&normalized_opcode));
        assert_eq!(intermediate_variables[&normalized_opcode].1, e);
    }

    #[test]
    fn stepwise_reduction_test() {
        let a = Witness(0);
        let b = Witness(1);
        let c = Witness(2);
        let d = Witness(3);
        let e = Witness(4);

        // a = b + c + d + e;
        let opcode_a = Expression::from_str(&format!("-{a} + {b} + {c} + {d} + {e}")).unwrap();

        let mut intermediate_variables: IndexMap<
            Expression<FieldElement>,
            (FieldElement, Witness),
        > = IndexMap::new();

        // Witnesses 0..=4 are already in use, so intermediate witnesses must start at 5;
        // otherwise the first intermediate variable would collide with `e`.
        let mut num_witness = 5;

        let mut optimizer = CSatTransformer::new(3);
        optimizer.mark_solvable(a);
        optimizer.mark_solvable(c);
        optimizer.mark_solvable(d);
        optimizer.mark_solvable(e);
        let got_optimized_opcode_a =
            optimizer.transform(opcode_a, &mut intermediate_variables, &mut num_witness);

        // Since b is not known, it cannot be put inside intermediate opcodes, so it must belong to the transformed opcode.
        let contains_b = got_optimized_opcode_a.linear_combinations.iter().any(|(_, w)| *w == b);
        assert!(contains_b);
    }

    #[test]
    fn recognize_expr_with_single_shared_witness_which_fits_in_single_identity() {
        // Regression test for an expression which Zac found which should have been preserved but
        // was being split into two expressions.
        let expr = Expression::from_str("-555*w8*w10 + w10 + w11 - w13").unwrap();
        assert!(fits_in_one_identity(&expr, 4));
    }

    #[test]
    #[should_panic(expected = "normalize: the first coefficient is non-invertible")]
    fn normalize_on_zero_linear_combination_panics() {
        let expr = Expression {
            mul_terms: vec![],
            linear_combinations: vec![(FieldElement::zero(), Witness(0))],
            q_c: FieldElement::zero(),
        };
        CSatTransformer::normalize(expr);
    }

    #[test]
    #[should_panic(expected = "normalize: the first coefficient is non-invertible")]
    fn normalize_on_zero_mul_term_scale_panics() {
        let expr = Expression {
            mul_terms: vec![(FieldElement::zero(), Witness(0), Witness(1))],
            linear_combinations: vec![],
            q_c: FieldElement::zero(),
        };
        CSatTransformer::normalize(expr);
    }

    #[test]
    #[should_panic(
        expected = "get_or_create_intermediate_var: attempting to divide l by F::zero()"
    )]
    fn get_or_create_intermediate_var_with_zero_panics() {
        let expr = Expression {
            mul_terms: vec![(FieldElement::one(), Witness(0), Witness(1))],
            linear_combinations: vec![],
            q_c: FieldElement::zero(),
        };

        let mut intermediate_variables = IndexMap::new();
        intermediate_variables.insert(expr.clone(), (FieldElement::zero(), Witness(0)));

        let mut num_witness = 2;

        CSatTransformer::get_or_create_intermediate_var(
            &mut intermediate_variables,
            expr,
            &mut num_witness,
        );
    }

    #[test]
    fn full_opcode_scan_optimization_extracts_full_opcodes() {
        // Expression: x*x + a*b + x + a + b + c + d + e + f + g
        //
        // With width=3 and two mul terms, full_opcode_scan_optimization extracts each
        // mul term together with 2 linear terms into an intermediate variable:
        //   t0 = x*x + x + g      (Witness 8)
        //   t1 = a*b + a + b      (Witness 9)
        //
        // The remaining expression becomes: t0 + t1 + c + d + e + f
        let x = Witness(0);
        let a = Witness(1);
        let b = Witness(2);
        let c = Witness(3);
        let d = Witness(4);
        let e = Witness(5);
        let f = Witness(6);
        let g = Witness(7);

        let opcode = Expression {
            mul_terms: vec![(FieldElement::one(), x, x), (FieldElement::one(), a, b)],
            linear_combinations: vec![
                (FieldElement::one(), x),
                (FieldElement::one(), a),
                (FieldElement::one(), b),
                (FieldElement::one(), c),
                (FieldElement::one(), d),
                (FieldElement::one(), e),
                (FieldElement::one(), f),
                (FieldElement::one(), g),
            ],
            q_c: FieldElement::zero(),
        };

        let mut intermediate_variables: IndexMap<
            Expression<FieldElement>,
            (FieldElement, Witness),
        > = IndexMap::new();
        let mut num_witness = 8u32;

        let mut optimizer = CSatTransformer::new(3);
        for w in [x, a, b, c, d, e, f, g] {
            optimizer.mark_solvable(w);
        }

        let result = optimizer.full_opcode_scan_optimization(
            opcode,
            &mut intermediate_variables,
            &mut num_witness,
        );

        // Both mul terms were replaced by intermediate variables; no mul terms remain.
        assert!(result.mul_terms.is_empty(), "all mul terms should be absorbed");

        // Each intermediate variable is full: it has one mul term and 2 linear terms.
        for intermediate in intermediate_variables.keys() {
            assert!(
                intermediate.mul_terms.len() == 1,
                "intermediate variables should be full opcodes"
            );
            assert!(
                intermediate.linear_combinations.len() == 2,
                "intermediate variables should be full opcodes"
            );
        }
        // Remaining linear terms: c, d, e, f, plus the intermediate vars t0 and t1.
        assert_eq!(result.linear_combinations.len(), 6);
    }

    #[test]
    #[should_panic(expected = "Could not reduce the expression")]
    fn single_solvable_term_in_intermediate_opcode_is_preserved() {
        // Test the case when len() is 1 in the line: 'if intermediate_opcode.linear_combinations.len() > 1 {'
        // Setup: width is 3, 4 terms [a, b, c, d], only 'a' is solvable.
        //
        // Because only 'a' is solvable, the intermediate_opcode will be [a], so its len is 1.
        // In this case, 'a' is added back to the opcode, which keeps it larger than the width
        // and triggers the panic.

        let a = Witness(0); // solvable
        let b = Witness(1); // unsolvable
        let c = Witness(2); // unsolvable
        let d = Witness(3); // unsolvable

        let opcode = Expression {
            mul_terms: vec![],
            linear_combinations: vec![
                (FieldElement::one(), a),
                (FieldElement::one(), b),
                (FieldElement::one(), c),
                (FieldElement::one(), d),
            ],
            q_c: FieldElement::zero(),
        };

        let mut intermediate_variables: IndexMap<
            Expression<FieldElement>,
            (FieldElement, Witness),
        > = IndexMap::new();

        let mut num_witness = 4;

        let mut optimizer = CSatTransformer::new(3);
        optimizer.mark_solvable(a);

        let _ = optimizer.transform(opcode, &mut intermediate_variables, &mut num_witness);
    }
}