acvm/compiler/optimizers/common_subexpression/csat.rs
1use std::{cmp::Ordering, collections::HashSet};
2
3use acir::{
4 AcirField,
5 native_types::{Expression, Witness},
6};
7use indexmap::IndexMap;
8
/// Minimum width accepted by the [`CSatTransformer`].
///
/// [`CSatTransformer::new`] asserts that the requested width is at least this value.
pub(crate) const MIN_EXPRESSION_WIDTH: usize = 3;
11
/// A transformer which processes any [`Expression`]s to break them up such that they
/// fit within the backend's width.
///
/// This is done by creating intermediate variables to hold partial calculations and then
/// combining them to calculate the original expression.
///
/// Pre-Condition:
/// - General Optimizer must run before this pass
pub(crate) struct CSatTransformer {
    /// Width that transformed expressions are made to fit into.
    width: usize,
    /// Witnesses known to be solvable. Only terms over these witnesses are moved into
    /// intermediate variables.
    solvable_witness: HashSet<Witness>,
}
25
26impl CSatTransformer {
27 /// Create an optimizer with a given width.
28 ///
29 /// Panics if `width` is less than `MIN_EXPRESSION_WIDTH`.
30 pub(crate) fn new(width: usize) -> CSatTransformer {
31 assert!(width >= MIN_EXPRESSION_WIDTH, "width has to be at least {MIN_EXPRESSION_WIDTH}");
32
33 CSatTransformer { width, solvable_witness: HashSet::new() }
34 }
35
36 /// Check if the equation 'expression=0' can be solved, and if yes, add the solved witness to set of solvable witness
37 fn try_solve<F>(&mut self, opcode: &Expression<F>) {
38 let mut unresolved = Vec::new();
39 for (_, w1, w2) in &opcode.mul_terms {
40 if !self.solvable_witness.contains(w1) {
41 unresolved.push(w1);
42 if !self.solvable_witness.contains(w2) {
43 return;
44 }
45 }
46 if !self.solvable_witness.contains(w2) {
47 unresolved.push(w2);
48 if !self.solvable_witness.contains(w1) {
49 return;
50 }
51 }
52 }
53 for (_, w) in &opcode.linear_combinations {
54 if !self.solvable_witness.contains(w) {
55 unresolved.push(w);
56 }
57 }
58 if unresolved.len() == 1 {
59 self.mark_solvable(*unresolved[0]);
60 }
61 }
62
63 /// Adds the witness to set of solvable witness
64 pub(crate) fn mark_solvable(&mut self, witness: Witness) {
65 self.solvable_witness.insert(witness);
66 }
67
68 /// Transform the input arithmetic expression into a new one having the correct 'width'
69 /// by creating intermediate variables as needed.
70 /// Having the correct width means:
71 /// - it has at most one multiplicative term
72 /// - it uses at most 'width-1' witness linear combination terms, to account for the new intermediate variable
73 pub(crate) fn transform<F: AcirField>(
74 &mut self,
75 opcode: Expression<F>,
76 intermediate_variables: &mut IndexMap<Expression<F>, (F, Witness)>,
77 num_witness: &mut u32,
78 ) -> Expression<F> {
79 // Here we create intermediate variables and constrain them to be equal to any subset of the polynomial that can be represented as a full opcode
80 let opcode =
81 self.full_opcode_scan_optimization(opcode, intermediate_variables, num_witness);
82 // The last optimization to do is to create intermediate variables in order to flatten the fan-in and the amount of mul terms
83 // If a opcode has more than one mul term. We may need an intermediate variable for each one. Since not every variable will need to link to
84 // the mul term, we could possibly do it that way.
85 // We wil call this a partial opcode scan optimization which will result in the opcodes being able to fit into the correct width
86 let mut opcode =
87 self.partial_opcode_scan_optimization(opcode, intermediate_variables, num_witness);
88 opcode.sort();
89 self.try_solve(&opcode);
90 opcode
91 }
92
93 // This optimization will search for combinations of terms which can be represented in a single assert-zero opcode
94 // Case 1 : qM * wL * wR + qL * wL + qR * wR + qO * wO + qC
95 // This polynomial does not require any further optimizations, it can be safely represented in one opcode
96 // ie a polynomial with 1 mul(bi-variate) term and 3 (univariate) terms where 2 of those terms match the bivariate term
97 // wL and wR, we can represent it in one opcode
98 // GENERALIZED for WIDTH: instead of the number 3, we use `WIDTH`
99 //
100 //
101 // Case 2: qM * wL * wR + qL * wL + qR * wR + qO * wO + qC + qM2 * wL2 * wR2 + qL * wL2 + qR * wR2 + qO * wO2 + qC2
102 // This polynomial cannot be represented using one assert-zero opcode.
103 //
104 // This algorithm will first extract the first full opcode(if possible):
105 // t = qM * wL * wR + qL * wL + qR * wR + qO * wO + qC
106 //
107 // The polynomial now looks like so t + qM2 * wL2 * wR2 + qL * wL2 + qR * wR2 + qO * wO2 + qC2
108 // This polynomial cannot be represented using one assert-zero opcode.
109 //
110 // This algorithm will then extract the second full opcode(if possible):
111 // t2 = qM2 * wL2 * wR2 + qL * wL2 + qR * wR2 + qO * wO2 + qC2
112 //
113 // The polynomial now looks like so t + t2
114 // We can no longer extract another full opcode, hence the algorithm terminates. Creating two intermediate variables t and t2.
115 // This stage of preprocessing does not guarantee that all polynomials can fit into a opcode. It only guarantees that all full opcodes have been extracted from each polynomial
116 fn full_opcode_scan_optimization<F: AcirField>(
117 &mut self,
118 mut opcode: Expression<F>,
119 intermediate_variables: &mut IndexMap<Expression<F>, (F, Witness)>,
120 num_witness: &mut u32,
121 ) -> Expression<F> {
122 // We pass around this intermediate variable IndexMap, so that we do not create intermediate variables that we have created before
123 // One instance where this might happen is t1 = wL * wR and t2 = wR * wL
124
125 // First check that this is not a simple opcode which does not need optimization
126 //
127 // If the opcode only has one mul term, then this algorithm cannot optimize it any further
128 // Either it can be represented in a single arithmetic equation or its fan-in is too large and we need intermediate variables for those
129 // Large-fan-in optimization is not this algorithm's purpose.
130 // If the opcode has 0 mul terms, then it is an add opcode and similarly it can either fit into a single assert-zero opcode or it has a large fan-in
131 if opcode.mul_terms.len() <= 1 {
132 return opcode;
133 }
134
135 // We now know that this opcode has multiple mul terms and can possibly be simplified into multiple full opcodes
136 // We need to create a (w_l, w_r) IndexMap and then check the simplified fan-in to verify if we have terms both with w_l and w_r
137 // In general, we can then push more terms into the opcode until we are at width-1 then the last variable will be the intermediate variable
138 //
139
140 // This will be our new opcode which will be equal to `self` except we will have intermediate variables that will be constrained to any
141 // subset of the terms that can be represented as full opcodes
142 let mut new_opcode = Expression::default();
143 let mut remaining_mul_terms = Vec::with_capacity(opcode.mul_terms.len());
144 for (scale, w_l, w_r) in opcode.mul_terms {
145 // We want to layout solvable intermediate variables, if we cannot solve one of the witnesses
146 // that means the intermediate opcode will not be immediately solvable
147 if !self.solvable_witness.contains(&w_l) || !self.solvable_witness.contains(&w_r) {
148 remaining_mul_terms.push((scale, w_l, w_r));
149 continue;
150 }
151
152 // Check if this (scale, w_l, w_r) triple is present in the simplified fan-in
153 // We are assuming that the fan-in/fan-out has been simplified.
154 // Note this function is not public, and can only be called within the optimize method, so this guarantee will always hold
155 let index_wl =
156 opcode.linear_combinations.iter().position(|(_scale, witness)| *witness == w_l);
157 let index_wr =
158 opcode.linear_combinations.iter().position(|(_scale, witness)| *witness == w_r);
159
160 match (index_wl, index_wr) {
161 (None, _) | (_, None) => {
162 // This means that the polynomial does not contain both terms
163 // Just push the Qm term as it cannot form a full opcode
164 new_opcode.mul_terms.push((scale, w_l, w_r));
165 }
166 (Some(x), Some(y)) => {
167 // This means that we can form a full opcode with this Qm term
168
169 // First fetch the left and right wires which match the mul term
170 let left_wire_term = opcode.linear_combinations[x];
171 let right_wire_term = opcode.linear_combinations[y];
172
173 // Lets create an intermediate opcode to store this full opcode
174 //
175 let mut intermediate_opcode = Expression::default();
176 intermediate_opcode.mul_terms.push((scale, w_l, w_r));
177
178 // Add the left and right wires
179 intermediate_opcode.linear_combinations.push(left_wire_term);
180 intermediate_opcode.linear_combinations.push(right_wire_term);
181 // Remove the left and right wires so we do not re-add them
182 match x.cmp(&y) {
183 Ordering::Greater => {
184 opcode.linear_combinations.remove(x);
185 opcode.linear_combinations.remove(y);
186 }
187 Ordering::Less => {
188 opcode.linear_combinations.remove(y);
189 opcode.linear_combinations.remove(x);
190 }
191 Ordering::Equal => {
192 opcode.linear_combinations.remove(x);
193 intermediate_opcode.linear_combinations.pop();
194 }
195 }
196
197 let used_space = intermediate_opcode.linear_combinations.len();
198 assert!(used_space < self.width);
199
200 // Now we have used up "used_space" spaces in our assert-zero opcode. The width now dictates how many more we can add
201 let mut remaining_space = self.width - used_space - 1; // We minus 1 because we need an extra space to contain the intermediate variable
202 // Keep adding terms until we have no more left, or we reach the width
203 let mut remaining_linear_terms =
204 Vec::with_capacity(opcode.linear_combinations.len());
205 while remaining_space > 0 {
206 if let Some(wire_term) = opcode.linear_combinations.pop() {
207 // Add this element into the new opcode
208 if self.solvable_witness.contains(&wire_term.1) {
209 intermediate_opcode.linear_combinations.push(wire_term);
210 remaining_space -= 1;
211 } else {
212 remaining_linear_terms.push(wire_term);
213 }
214 } else {
215 // No more usable elements left in the old opcode
216 break;
217 }
218 }
219 opcode.linear_combinations.extend(remaining_linear_terms);
220
221 // Constrain this intermediate_opcode to be equal to the temp variable by adding it into the IndexMap
222 // We need a unique name for our intermediate variable
223 // TODO(https://github.com/noir-lang/noir/issues/10192): Another optimization, which could be applied in another algorithm
224 // If two opcodes have a large fan-in/out and they share a few common terms, then we should create intermediate variables for them
225 // Do some sort of subset matching algorithm for this on the terms of the polynomial
226 let intermediate_var = Self::get_or_create_intermediate_var(
227 intermediate_variables,
228 intermediate_opcode,
229 num_witness,
230 );
231
232 // Add intermediate variable to the new opcode instead of the full opcode
233 self.mark_solvable(intermediate_var.1);
234 new_opcode.linear_combinations.push(intermediate_var);
235 }
236 }
237 }
238
239 // Add the rest of the elements back into the new_opcode
240 new_opcode.mul_terms.extend(remaining_mul_terms);
241 new_opcode.linear_combinations.extend(opcode.linear_combinations);
242 new_opcode.q_c = opcode.q_c;
243 new_opcode.sort();
244 new_opcode
245 }
246
247 /// Normalize an expression by dividing it by its first coefficient
248 /// The first coefficient here means coefficient of the first linear term, or of the first quadratic term if no linear terms exist.
249 /// This function panics if the input expression is constant or if the first coefficient's inverse is F::zero()
250 fn normalize<F: AcirField>(mut expr: Expression<F>) -> (F, Expression<F>) {
251 expr.sort();
252 let a = if !expr.linear_combinations.is_empty() {
253 expr.linear_combinations[0].0
254 } else {
255 expr.mul_terms[0].0
256 };
257 let a_inverse = a.inverse();
258 assert!(a_inverse != F::zero(), "normalize: the first coefficient is non-invertible");
259 (a, &expr * a_inverse)
260 }
261
262 /// Get or generate a scaled intermediate witness which is equal to the provided expression
263 /// The sets of previously generated witness and their (normalized) expression is cached in the intermediate_variables map
264 /// If there is no cache hit, we generate a new witness (and add the expression to the cache)
265 /// else, we return the cached witness along with the scaling factor so it is equal to the provided expression
266 fn get_or_create_intermediate_var<F: AcirField>(
267 intermediate_variables: &mut IndexMap<Expression<F>, (F, Witness)>,
268 expr: Expression<F>,
269 num_witness: &mut u32,
270 ) -> (F, Witness) {
271 let (k, normalized_expr) = Self::normalize(expr);
272
273 if intermediate_variables.contains_key(&normalized_expr) {
274 let (l, iv) = intermediate_variables[&normalized_expr];
275 assert!(
276 l != F::zero(),
277 "get_or_create_intermediate_var: attempting to divide l by F::zero()"
278 );
279 (k / l, iv)
280 } else {
281 let inter_var = Witness(*num_witness);
282 *num_witness += 1;
283 // Add intermediate opcode and variable to map
284 intermediate_variables.insert(normalized_expr, (k, inter_var));
285 (F::one(), inter_var)
286 }
287 }
288
289 // A partial opcode scan optimization aim to create intermediate variables in order to compress the polynomial
290 // So that it fits within the given width
291 // Note that this opcode follows the full opcode scan optimization.
292 // We define the partial width as equal to the full width - 2.
293 // This is because two of our variables cannot be used as they are linked to the multiplication terms
294 // Example: qM1 * wL1 * wR2 + qL1 * wL3 + qR1 * wR4+ qR2 * wR5 + qO1 * wO5 + qC
295 // One thing to note is that the multiplication wires do not match any of the fan-in/out wires. This is guaranteed as we have
296 // just completed the full opcode optimization algorithm.
297 //
298 // Actually we can optimize in two ways here: We can create an intermediate variable which is equal to the fan-in terms
299 // t = qL1 * wL3 + qR1 * wR4 -> width = 3
300 // This `t` value can only use width - 1 terms
301 // The opcode now looks like: qM1 * wL1 * wR2 + t + qR2 * wR5+ qO1 * wO5 + qC
302 // But this is still not acceptable since wR5 is not wR2, so we need another intermediate variable
303 // t2 = t + qR2 * wR5
304 //
305 // The opcode now looks like: qM1 * wL1 * wR2 + t2 + qO1 * wO5 + qC
306 // This is still not good, so we do it one more time:
307 // t3 = t2 + qO1 * wO5
308 // The opcode now looks like: qM1 * wL1 * wR2 + t3 + qC
309 //
310 // Another strategy is to create a temporary variable for the multiplier term and then we can see it as a term in the fan-in
311 //
312 // Same Example: qM1 * wL1 * wR2 + qL1 * wL3 + qR1 * wR4+ qR2 * wR5 + qO1 * wO5 + qC
313 // t = qM1 * wL1 * wR2
314 // The opcode now looks like: t + qL1 * wL3 + qR1 * wR4+ qR2 * wR5 + qO1 * wO5 + qC
315 // Still assuming width-3, we still need to use width-1 terms for the intermediate variables, however we can stop at an earlier stage because
316 // the opcode does not need the multiplier term to match with any of the fan-in terms
317 // t2 = t + qL1 * wL3
318 // The opcode now looks like: t2 + qR1 * wR4+ qR2 * wR5 + qO1 * wO5 + qC
319 // t3 = t2 + qR1 * wR4
320 // The opcode now looks like: t3 + qR2 * wR5 + qO1 * wO5 + qC
321 // This took the same amount of opcodes, but which one is better when the width increases? Compute this and maybe do both optimizations
322 // naming : partial_opcode_mul_first_opt and partial_opcode_fan_first_opt
323 // Also remember that since we did full opcode scan, there is no way we can have a non-zero mul term along with the wL and wR terms being non-zero
324 //
325 // Cases, a lot of mul terms, a lot of fan-in terms, 50/50
326 fn partial_opcode_scan_optimization<F: AcirField>(
327 &mut self,
328 mut opcode: Expression<F>,
329 intermediate_variables: &mut IndexMap<Expression<F>, (F, Witness)>,
330 num_witness: &mut u32,
331 ) -> Expression<F> {
332 // We will go for the easiest route, which is to convert all multiplications into additions using intermediate variables
333 // Then use intermediate variables again to squash the fan-in, so that it can fit into the appropriate width
334
335 // First check if this polynomial actually needs a partial opcode optimization
336 // There is the chance that it fits perfectly within the assert-zero opcode
337 if fits_in_one_identity(&opcode, self.width) {
338 return opcode;
339 }
340
341 // Create Intermediate variables for the multiplication opcodes
342 let mut remaining_mul_terms = Vec::with_capacity(opcode.mul_terms.len());
343 for (scale, w_l, w_r) in opcode.mul_terms {
344 if self.solvable_witness.contains(&w_l) && self.solvable_witness.contains(&w_r) {
345 let mut intermediate_opcode = Expression::default();
346
347 // Push mul term into the opcode
348 intermediate_opcode.mul_terms.push((scale, w_l, w_r));
349 // Get an intermediate variable which squashes the multiplication term
350 let intermediate_var = Self::get_or_create_intermediate_var(
351 intermediate_variables,
352 intermediate_opcode,
353 num_witness,
354 );
355
356 // Add intermediate variable as a part of the fan-in for the original opcode
357 opcode.linear_combinations.push(intermediate_var);
358 self.mark_solvable(intermediate_var.1);
359 } else {
360 remaining_mul_terms.push((scale, w_l, w_r));
361 }
362 }
363
364 // Remove all of the mul terms as we have intermediate variables to represent them now
365 opcode.mul_terms = remaining_mul_terms;
366
367 // We now only have a polynomial with only fan-in/fan-out terms i.e. terms of the form Ax + By + Cd + ...
368 // Lets create intermediate variables if all of them cannot fit into the width
369 //
370 // If the polynomial fits perfectly within the given width, we are finished
371 if opcode.linear_combinations.len() <= self.width {
372 return opcode;
373 }
374
375 // Stores the intermediate variables that are used to
376 // reduce the fan in.
377 let mut added_vars = Vec::new();
378
379 while opcode.linear_combinations.len() > self.width {
380 // Collect as many terms up to the given width-1 and constrain them to an intermediate variable
381 let mut intermediate_opcode = Expression::default();
382
383 let mut remaining_linear_terms = Vec::with_capacity(opcode.linear_combinations.len());
384
385 for term in opcode.linear_combinations {
386 if self.solvable_witness.contains(&term.1)
387 && intermediate_opcode.linear_combinations.len() < self.width - 1
388 {
389 intermediate_opcode.linear_combinations.push(term);
390 } else {
391 remaining_linear_terms.push(term);
392 }
393 }
394 opcode.linear_combinations = remaining_linear_terms;
395 let not_full = intermediate_opcode.linear_combinations.len() < self.width - 1;
396 if intermediate_opcode.linear_combinations.len() > 1 {
397 let intermediate_var = Self::get_or_create_intermediate_var(
398 intermediate_variables,
399 intermediate_opcode,
400 num_witness,
401 );
402 self.mark_solvable(intermediate_var.1);
403 added_vars.push(intermediate_var);
404 } else {
405 // Put back the single term that couldn't form an intermediate variable
406 opcode.linear_combinations.extend(intermediate_opcode.linear_combinations);
407 }
408 // The intermediate opcode is not full, but the opcode still has too many terms
409 if not_full && opcode.linear_combinations.len() > self.width {
410 unreachable!("Could not reduce the expression");
411 }
412 }
413
414 // Add back the intermediate variables to
415 // keep consistency with the original equation.
416 opcode.linear_combinations.extend(added_vars);
417 self.partial_opcode_scan_optimization(opcode, intermediate_variables, num_witness)
418 }
419}
420
421/// Checks if this expression can fit into one arithmetic identity
422fn fits_in_one_identity<F: AcirField>(expr: &Expression<F>, width: usize) -> bool {
423 // A Polynomial with more than one mul term cannot fit into one opcode
424 if expr.mul_terms.len() > 1 {
425 return false;
426 }
427
428 expr.width() <= width
429}
430
431#[cfg(test)]
432mod tests {
433 use super::*;
434 use acir::FieldElement;
435
436 #[test]
437 fn simple_reduction_smoke_test() {
438 let a = Witness(0);
439 let b = Witness(1);
440 let c = Witness(2);
441 let d = Witness(3);
442
443 // a = b + c + d;
444 let opcode_a = Expression::from_str(&format!("{a} - {b} - {c} - {d}")).unwrap();
445
446 let mut intermediate_variables: IndexMap<
447 Expression<FieldElement>,
448 (FieldElement, Witness),
449 > = IndexMap::new();
450
451 let mut num_witness = 4;
452
453 let mut optimizer = CSatTransformer::new(3);
454 optimizer.mark_solvable(b);
455 optimizer.mark_solvable(c);
456 optimizer.mark_solvable(d);
457 let got_optimized_opcode_a =
458 optimizer.transform(opcode_a, &mut intermediate_variables, &mut num_witness);
459
460 // a = b + c + d => a - b - c - d = 0
461 // For width3, the result becomes:
462 // a - d + e = 0
463 // - c - b - e = 0
464 //
465 // a - b + e = 0
466 let e = Witness(4);
467 let expected_optimized_opcode_a =
468 Expression::from_str(&format!("{a} - {d} + {e}")).unwrap();
469
470 assert_eq!(expected_optimized_opcode_a, got_optimized_opcode_a);
471
472 assert_eq!(intermediate_variables.len(), 1);
473
474 // e = - c - b
475 let expected_intermediate_opcode = Expression::from_str(&format!("-{c} - {b}")).unwrap();
476 let (_, normalized_opcode) = CSatTransformer::normalize(expected_intermediate_opcode);
477 assert!(intermediate_variables.contains_key(&normalized_opcode));
478 assert_eq!(intermediate_variables[&normalized_opcode].1, e);
479 }
480
481 #[test]
482 fn stepwise_reduction_test() {
483 let a = Witness(0);
484 let b = Witness(1);
485 let c = Witness(2);
486 let d = Witness(3);
487 let e = Witness(4);
488
489 // a = b + c + d + e;
490 let opcode_a = Expression::from_str(&format!("-{a} + {b} + {c} + {d} + {e}")).unwrap();
491
492 let mut intermediate_variables: IndexMap<
493 Expression<FieldElement>,
494 (FieldElement, Witness),
495 > = IndexMap::new();
496
497 let mut num_witness = 4;
498
499 let mut optimizer = CSatTransformer::new(3);
500 optimizer.mark_solvable(a);
501 optimizer.mark_solvable(c);
502 optimizer.mark_solvable(d);
503 optimizer.mark_solvable(e);
504 let got_optimized_opcode_a =
505 optimizer.transform(opcode_a, &mut intermediate_variables, &mut num_witness);
506
507 // Since b is not known, it cannot be put inside intermediate opcodes, so it must belong to the transformed opcode.
508 let contains_b = got_optimized_opcode_a.linear_combinations.iter().any(|(_, w)| *w == b);
509 assert!(contains_b);
510 }
511
512 #[test]
513 fn recognize_expr_with_single_shared_witness_which_fits_in_single_identity() {
514 // Regression test for an expression which Zac found which should have been preserved but
515 // was being split into two expressions.
516 let expr = Expression::from_str("-555*w8*w10 + w10 + w11 - w13").unwrap();
517 assert!(fits_in_one_identity(&expr, 4));
518 }
519
520 #[test]
521 #[should_panic(expected = "normalize: the first coefficient is non-invertible")]
522 fn normalize_on_zero_linear_combination_panics() {
523 let expr = Expression {
524 mul_terms: vec![],
525 linear_combinations: vec![(FieldElement::zero(), Witness(0))],
526 q_c: FieldElement::zero(),
527 };
528 CSatTransformer::normalize(expr);
529 }
530
531 #[test]
532 #[should_panic(expected = "normalize: the first coefficient is non-invertible")]
533 fn normalize_on_zero_mul_term_scale_panics() {
534 let expr = Expression {
535 mul_terms: vec![(FieldElement::zero(), Witness(0), Witness(1))],
536 linear_combinations: vec![],
537 q_c: FieldElement::zero(),
538 };
539 CSatTransformer::normalize(expr);
540 }
541
542 #[test]
543 #[should_panic(
544 expected = "get_or_create_intermediate_var: attempting to divide l by F::zero()"
545 )]
546 fn get_or_create_intermediate_var_with_zero_panics() {
547 let expr = Expression {
548 mul_terms: vec![(FieldElement::one(), Witness(0), Witness(1))],
549 linear_combinations: vec![],
550 q_c: FieldElement::zero(),
551 };
552
553 let mut intermediate_variables = IndexMap::new();
554 intermediate_variables.insert(expr.clone(), (FieldElement::zero(), Witness(0)));
555
556 let mut num_witness = 2;
557
558 CSatTransformer::get_or_create_intermediate_var(
559 &mut intermediate_variables,
560 expr,
561 &mut num_witness,
562 );
563 }
564
565 #[test]
566 fn full_opcode_scan_optimization_extracts_full_opcodes() {
567 // Expression: x*x + a*b + x + a + b + c + d + e + f + g
568 //
569 // With width=3 and two mul terms, full_opcode_scan_optimization extracts each
570 // mul term together with 2 linear terms into an intermediate variable:
571 // t0 = x*x + x + g (Witness 8)
572 // t1 = a*b + a + b (Witness 9)
573 //
574 // The remaining expression becomes: t0 + t1 + c + d + e + f
575 let x = Witness(0);
576 let a = Witness(1);
577 let b = Witness(2);
578 let c = Witness(3);
579 let d = Witness(4);
580 let e = Witness(5);
581 let f = Witness(6);
582 let g = Witness(7);
583
584 let opcode = Expression {
585 mul_terms: vec![(FieldElement::one(), x, x), (FieldElement::one(), a, b)],
586 linear_combinations: vec![
587 (FieldElement::one(), x),
588 (FieldElement::one(), a),
589 (FieldElement::one(), b),
590 (FieldElement::one(), c),
591 (FieldElement::one(), d),
592 (FieldElement::one(), e),
593 (FieldElement::one(), f),
594 (FieldElement::one(), g),
595 ],
596 q_c: FieldElement::zero(),
597 };
598
599 let mut intermediate_variables: IndexMap<
600 Expression<FieldElement>,
601 (FieldElement, Witness),
602 > = IndexMap::new();
603 let mut num_witness = 8u32;
604
605 let mut optimizer = CSatTransformer::new(3);
606 for w in [x, a, b, c, d, e, f, g] {
607 optimizer.mark_solvable(w);
608 }
609
610 let result = optimizer.full_opcode_scan_optimization(
611 opcode,
612 &mut intermediate_variables,
613 &mut num_witness,
614 );
615
616 // Both mul terms were replaced by intermediate variables; no mul terms remain.
617 assert!(result.mul_terms.is_empty(), "all mul terms should be absorbed");
618
619 // Each intermediate variable is full: it has 2 linear terms.
620 for intermediate in intermediate_variables.keys() {
621 assert!(
622 intermediate.mul_terms.len() == 1,
623 "intermediate variables should be full opcodes"
624 );
625 assert!(
626 intermediate.linear_combinations.len() == 2,
627 "intermediate variables should be full opcodes"
628 );
629 }
630 // Remaining linear terms: c, d, e, f+ t0, t1 (intermediate vars).
631 assert_eq!(result.linear_combinations.len(), 6);
632 }
633
634 #[test]
635 #[should_panic(expected = "Could not reduce the expression")]
636 fn single_solvable_term_in_intermediate_opcode_is_preserved() {
637 // Test the case when len() is 1 in the line: 'if intermediate_opcode.linear_combinations.len() > 1 {'
638 // Setup: width is 3, 4 terms [a, b, c, d], only 'a' is solvable.
639 //
640 // Because only 'a' is solvable, the intermediate_opcode will be [a], so its len is 1.
641 // In this case, 'a' should be added back to the opcode which makes it larger than the width and trigger the panic.
642
643 let a = Witness(0); // solvable
644 let b = Witness(1); // unsolvable
645 let c = Witness(2); // unsolvable
646 let d = Witness(3); // unsolvable
647
648 let opcode = Expression {
649 mul_terms: vec![],
650 linear_combinations: vec![
651 (FieldElement::one(), a),
652 (FieldElement::one(), b),
653 (FieldElement::one(), c),
654 (FieldElement::one(), d),
655 ],
656 q_c: FieldElement::zero(),
657 };
658
659 let mut intermediate_variables: IndexMap<
660 Expression<FieldElement>,
661 (FieldElement, Witness),
662 > = IndexMap::new();
663
664 let mut num_witness = 4;
665
666 let mut optimizer = CSatTransformer::new(3);
667 optimizer.mark_solvable(a);
668
669 let _ = optimizer.transform(opcode, &mut intermediate_variables, &mut num_witness);
670 }
671}