acir/native_types/expression/
mod.rs1use crate::{circuit::PublicInputs, native_types::Witness};
2use acir_field::AcirField;
3use serde::{Deserialize, Serialize};
4use std::cmp::Ordering;
5mod operators;
6
7#[derive(Clone, Debug, PartialEq, Eq, Serialize, Deserialize, Hash)]
24#[cfg_attr(feature = "arb", derive(proptest_derive::Arbitrary))]
25pub struct Expression<F> {
26 pub mul_terms: Vec<(F, Witness, Witness)>,
33
34 pub linear_combinations: Vec<(F, Witness)>,
39 pub q_c: F,
43}
44
45impl<F: AcirField> Default for Expression<F> {
46 fn default() -> Self {
47 Expression { mul_terms: Vec::new(), linear_combinations: Vec::new(), q_c: F::zero() }
48 }
49}
50
51impl<F: AcirField> std::fmt::Display for Expression<F> {
52 fn fmt(&self, f: &mut std::fmt::Formatter) -> std::fmt::Result {
53 display_expression(self, false, None, f)
54 }
55}
56
57impl<F> Expression<F> {
58 pub fn num_mul_terms(&self) -> usize {
60 self.mul_terms.len()
61 }
62
63 pub fn push_addition_term(&mut self, coefficient: F, variable: Witness) {
65 self.linear_combinations.push((coefficient, variable));
66 }
67
68 pub fn push_multiplication_term(&mut self, coefficient: F, lhs: Witness, rhs: Witness) {
70 self.mul_terms.push((coefficient, lhs, rhs));
71 }
72
73 pub fn is_const(&self) -> bool {
80 self.mul_terms.is_empty() && self.linear_combinations.is_empty()
81 }
82
83 pub fn to_const(&self) -> Option<&F> {
93 self.is_const().then_some(&self.q_c)
94 }
95
96 pub fn is_linear(&self) -> bool {
108 self.mul_terms.is_empty()
109 }
110
111 pub fn is_degree_one_univariate(&self) -> bool {
131 self.is_linear() && self.linear_combinations.len() == 1
132 }
133
134 pub fn sort(&mut self) {
138 for term in &mut self.mul_terms {
139 if term.1 > term.2 {
140 std::mem::swap(&mut term.1, &mut term.2);
141 }
142 }
143 self.mul_terms.sort_by(|a, b| a.1.cmp(&b.1).then(a.2.cmp(&b.2)));
144 self.linear_combinations.sort_by(|a, b| a.1.cmp(&b.1));
145 }
146
147 #[cfg(test)]
148 pub(crate) fn is_sorted(&self) -> bool {
149 self.mul_terms.iter().all(|term| term.1 <= term.2)
150 && self.mul_terms.iter().is_sorted_by(|a, b| a.1.cmp(&b.1).then(a.2.cmp(&b.2)).is_le())
151 && self.linear_combinations.iter().is_sorted_by(|a, b| a.1.cmp(&b.1).is_le())
152 }
153}
154
155impl<F: AcirField> Expression<F> {
156 pub fn from_field(q_c: F) -> Self {
157 Self { q_c, ..Default::default() }
158 }
159
160 pub fn zero() -> Self {
161 Self::default()
162 }
163
164 pub fn is_zero(&self) -> bool {
165 *self == Self::zero()
166 }
167
168 pub fn one() -> Self {
169 Self::from_field(F::one())
170 }
171
172 pub fn is_one(&self) -> bool {
173 *self == Self::one()
174 }
175
176 pub fn to_witness(&self) -> Option<Witness> {
183 if self.is_degree_one_univariate() {
184 let (coefficient, variable) = self.linear_combinations[0];
189 let constant = self.q_c;
190
191 if coefficient.is_one() && constant.is_zero() {
192 return Some(variable);
193 }
194 }
195 None
196 }
197
198 pub fn add_mul(&self, k: F, b: &Self) -> Self {
200 if k.is_zero() {
201 return self.clone();
202 } else if self.is_const() {
203 let kb = b * k;
204 return kb + self.q_c;
205 } else if b.is_const() {
206 return self.clone() + (k * b.q_c);
207 }
208
209 let mut mul_terms: Vec<(F, Witness, Witness)> =
210 Vec::with_capacity(self.mul_terms.len() + b.mul_terms.len());
211 let mut linear_combinations: Vec<(F, Witness)> =
212 Vec::with_capacity(self.linear_combinations.len() + b.linear_combinations.len());
213 let q_c = self.q_c + k * b.q_c;
214
215 let mut i1 = 0; let mut i2 = 0; while i1 < self.linear_combinations.len() && i2 < b.linear_combinations.len() {
219 let (a_c, a_w) = self.linear_combinations[i1];
220 let (b_c, b_w) = b.linear_combinations[i2];
221
222 let (coeff, witness) = match a_w.cmp(&b_w) {
223 Ordering::Greater => {
224 i2 += 1;
225 (k * b_c, b_w)
226 }
227 Ordering::Less => {
228 i1 += 1;
229 (a_c, a_w)
230 }
231 Ordering::Equal => {
232 i1 += 1;
235 i2 += 1;
236 (a_c + k * b_c, a_w)
237 }
238 };
239
240 if !coeff.is_zero() {
241 linear_combinations.push((coeff, witness));
242 }
243 }
244
245 while i1 < self.linear_combinations.len() {
247 linear_combinations.push(self.linear_combinations[i1]);
248 i1 += 1;
249 }
250 while i2 < b.linear_combinations.len() {
251 let (b_c, b_w) = b.linear_combinations[i2];
252 let coeff = b_c * k;
253 if !coeff.is_zero() {
254 linear_combinations.push((coeff, b_w));
255 }
256 i2 += 1;
257 }
258
259 i1 = 0; i2 = 0; while i1 < self.mul_terms.len() && i2 < b.mul_terms.len() {
264 let (a_c, a_wl, a_wr) = self.mul_terms[i1];
265 let (b_c, b_wl, b_wr) = b.mul_terms[i2];
266
267 let (coeff, wl, wr) = match (a_wl, a_wr).cmp(&(b_wl, b_wr)) {
268 Ordering::Greater => {
269 i2 += 1;
270 (k * b_c, b_wl, b_wr)
271 }
272 Ordering::Less => {
273 i1 += 1;
274 (a_c, a_wl, a_wr)
275 }
276 Ordering::Equal => {
277 i2 += 1;
280 i1 += 1;
281 (a_c + k * b_c, a_wl, a_wr)
282 }
283 };
284
285 if !coeff.is_zero() {
286 mul_terms.push((coeff, wl, wr));
287 }
288 }
289
290 while i1 < self.mul_terms.len() {
292 mul_terms.push(self.mul_terms[i1]);
293 i1 += 1;
294 }
295 while i2 < b.mul_terms.len() {
296 let (b_c, b_wl, b_wr) = b.mul_terms[i2];
297 let coeff = b_c * k;
298 if coeff != F::zero() {
299 mul_terms.push((coeff, b_wl, b_wr));
300 }
301 i2 += 1;
302 }
303
304 Expression { mul_terms, linear_combinations, q_c }
305 }
306
307 pub fn width(&self) -> usize {
314 if self.mul_terms.len() > 1 {
315 unimplemented!("ICE - width() does not support expressions with multiple mul terms");
316 }
317
318 let mut width = 0;
319
320 for mul_term in &self.mul_terms {
321 assert_ne!(mul_term.0, F::zero());
323
324 let mut found_x = false;
325 let mut found_y = false;
326
327 for term in &self.linear_combinations {
328 let witness = &term.1;
329 let x = &mul_term.1;
330 let y = &mul_term.2;
331 if witness == x {
332 found_x = true;
333 }
334 if witness == y {
335 found_y = true;
336 }
337 if found_x & found_y {
338 break;
339 }
340 }
341
342 let multiplication_is_squaring = mul_term.1 == mul_term.2;
345
346 let mul_term_width_contribution = if !multiplication_is_squaring && (found_x & found_y)
347 {
348 0
352 } else if found_x || found_y {
353 1
356 } else {
357 2
359 };
360
361 width += mul_term_width_contribution;
362 }
363
364 width += self.linear_combinations.len();
365
366 width
367 }
368}
369
370impl<F: AcirField> From<F> for Expression<F> {
371 fn from(constant: F) -> Self {
372 Expression { q_c: constant, linear_combinations: Vec::new(), mul_terms: Vec::new() }
373 }
374}
375
376impl<F: AcirField> From<Witness> for Expression<F> {
377 fn from(wit: Witness) -> Self {
383 Expression {
384 q_c: F::zero(),
385 linear_combinations: vec![(F::one(), wit)],
386 mul_terms: Vec::new(),
387 }
388 }
389}
390
391pub(crate) fn display_expression<F: AcirField>(
397 expr: &Expression<F>,
398 as_equal_to_zero: bool,
399 return_values: Option<&PublicInputs>,
400 f: &mut std::fmt::Formatter<'_>,
401) -> std::fmt::Result {
402 let mut assignment_witness: Option<usize> = None;
405
406 let mut negate_coefficients = false;
410
411 let linear_witness_one = if as_equal_to_zero {
414 let linear_witness_one = return_values.and_then(|return_values| {
416 expr.linear_combinations.iter().enumerate().find(|(_, (coefficient, witness))| {
417 (coefficient.is_one() || (-*coefficient).is_one())
418 && return_values.0.contains(witness)
419 })
420 });
421 linear_witness_one.or_else(|| {
422 expr.linear_combinations
424 .iter()
425 .enumerate()
426 .filter(|(_, (coefficient, _))| coefficient.is_one() || (-*coefficient).is_one())
427 .max_by_key(|(_, (_, witness))| witness)
428 })
429 } else {
430 None
431 };
432
433 if let Some((index, (coefficient, witness))) = linear_witness_one {
436 assignment_witness = Some(index);
437 negate_coefficients = coefficient.is_one();
438 write!(f, "{witness} = ")?;
439 } else if as_equal_to_zero {
440 write!(f, "0 = ")?;
441 }
442
443 let mut printed_term = false;
444
445 for (coefficient, witness1, witness2) in &expr.mul_terms {
446 let witnesses = [*witness1, *witness2];
447 display_term(*coefficient, witnesses, printed_term, negate_coefficients, f)?;
448 printed_term = true;
449 }
450
451 for (index, (coefficient, witness)) in expr.linear_combinations.iter().enumerate() {
452 if assignment_witness
453 .is_some_and(|show_as_assignment_index| show_as_assignment_index == index)
454 {
455 continue;
457 }
458
459 let witnesses = [*witness];
460 display_term(*coefficient, witnesses, printed_term, negate_coefficients, f)?;
461 printed_term = true;
462 }
463
464 if expr.q_c.is_zero() {
465 if !printed_term {
466 write!(f, "0")?;
467 }
468 } else {
469 let coefficient = expr.q_c;
470 let coefficient = if negate_coefficients { -coefficient } else { coefficient };
471 let coefficient_as_string = coefficient.to_string();
472 let coefficient_is_negative = coefficient_as_string.starts_with('-');
473
474 if printed_term {
475 if coefficient_is_negative {
476 write!(f, " - ")?;
477 } else {
478 write!(f, " + ")?;
479 }
480 }
481
482 let coefficient =
483 if printed_term && coefficient_is_negative { -coefficient } else { coefficient };
484 write!(f, "{coefficient}")?;
485 }
486
487 Ok(())
488}
489
490fn display_term<F: AcirField, const N: usize>(
491 coefficient: F,
492 witnesses: [Witness; N],
493 printed_term: bool,
494 negate_coefficients: bool,
495 f: &mut std::fmt::Formatter<'_>,
496) -> std::fmt::Result {
497 let coefficient = if negate_coefficients { -coefficient } else { coefficient };
498 let coefficient_as_string = coefficient.to_string();
499 let coefficient_is_negative = coefficient_as_string.starts_with('-');
500
501 if printed_term {
502 if coefficient_is_negative {
503 write!(f, " - ")?;
504 } else {
505 write!(f, " + ")?;
506 }
507 }
508
509 let coefficient =
510 if printed_term && coefficient_is_negative { -coefficient } else { coefficient };
511
512 if coefficient.is_one() {
513 } else if (-coefficient).is_one() {
515 write!(f, "-")?;
516 } else {
517 write!(f, "{coefficient}*")?;
518 }
519
520 for (index, witness) in witnesses.iter().enumerate() {
521 if index != 0 {
522 write!(f, "*")?;
523 }
524 write!(f, "{witness}")?;
525 }
526
527 Ok(())
528}
529
#[cfg(test)]
mod tests {
    use super::*;
    use acir_field::FieldElement;

    /// Parses an expression over `FieldElement`, panicking if the text is malformed.
    fn parse(text: &str) -> Expression<FieldElement> {
        Expression::from_str(text).unwrap()
    }

    #[test]
    fn add_mul_smoke_test() {
        let lhs = parse("2*w1*w2");
        let rhs = parse("3*w0*w2 + 3*w1*w2 + 4*w4*w5 + 4*w4 + 1");

        let sum = lhs.add_mul(FieldElement::from(10u128), &rhs);
        assert_eq!(sum.to_string(), "30*w0*w2 + 32*w1*w2 + 40*w4*w5 + 40*w4 + 10");
    }

    #[test]
    fn add_mul_with_zero_coefficient() {
        let lhs = parse("2*w1*w2 + 3*w1 + 5");
        let rhs = parse("4*w2*w3 + 6*w2 + 7");

        // Scaling by zero must hand back `lhs` untouched.
        let sum = lhs.add_mul(FieldElement::zero(), &rhs);
        assert_eq!(sum, lhs);
    }

    #[test]
    fn add_mul_when_self_is_const() {
        let constant = Expression::from_field(FieldElement::from(5u128));
        let rhs = parse("2*w1*w2 + 3*w1 + 4");

        let sum = constant.add_mul(FieldElement::from(2u128), &rhs);
        assert_eq!(sum.to_string(), "4*w1*w2 + 6*w1 + 13");
    }

    #[test]
    fn add_mul_when_b_is_const() {
        let lhs = parse("2*w1*w2 + 3*w1 + 4");
        let constant = Expression::from_field(FieldElement::from(5u128));

        let sum = lhs.add_mul(FieldElement::from(3u128), &constant);
        assert_eq!(sum.to_string(), "2*w1*w2 + 3*w1 + 19");
    }

    #[test]
    fn add_mul_merges_linear_terms() {
        let lhs = parse("5*w1 + 3*w2");
        let rhs = parse("2*w1 + 4*w3");

        let sum = lhs.add_mul(FieldElement::from(2u128), &rhs);
        assert_eq!(sum.to_string(), "9*w1 + 3*w2 + 8*w3");
    }

    #[test]
    fn add_mul_merges_mul_terms() {
        let lhs = parse("5*w1*w2 + 3*w3*w4");
        let rhs = parse("2*w1*w2 + 4*w5*w6");

        let sum = lhs.add_mul(FieldElement::from(3u128), &rhs);
        assert_eq!(sum.to_string(), "11*w1*w2 + 3*w3*w4 + 12*w5*w6");
    }

    #[test]
    fn add_mul_cancels_terms_to_zero() {
        let lhs = parse("6*w1 + 3*w1*w2");
        let rhs = parse("3*w1 + w1*w2");

        // 6*w1 - 2*(3*w1) vanishes entirely; 3*w1*w2 - 2*(w1*w2) leaves w1*w2.
        let sum = lhs.add_mul(FieldElement::from(-2i128), &rhs);
        assert_eq!(sum.to_string(), "w1*w2");
    }

    #[test]
    fn add_mul_maintains_sorted_order() {
        let lhs = parse("w5 + w1*w3");
        let rhs = parse("w2 + w0*w1");

        let sum = lhs.add_mul(FieldElement::one(), &rhs);
        assert!(sum.is_sorted());
        assert_eq!(sum.to_string(), "w0*w1 + w1*w3 + w2 + w5");
    }

    #[test]
    fn add_mul_with_constant_terms() {
        let lhs = parse("2*w1 + 10");
        let rhs = parse("3*w2 + 5");

        let sum = lhs.add_mul(FieldElement::from(4u128), &rhs);
        assert_eq!(sum.to_string(), "2*w1 + 12*w2 + 30");
    }

    #[test]
    fn add_mul_complex_expression() {
        let lhs = parse("2*w1*w2 + 3*w3*w4 + 5*w1 + 7*w3 + 11");
        let rhs = parse("w1*w2 + 4*w5*w6 + 2*w1 + 6*w5 + 13");

        let sum = lhs.add_mul(FieldElement::from(2u128), &rhs);
        assert_eq!(sum.to_string(), "4*w1*w2 + 3*w3*w4 + 8*w5*w6 + 9*w1 + 7*w3 + 12*w5 + 37");
    }

    #[test]
    fn display_zero() {
        let empty = Expression::<FieldElement>::default();
        assert_eq!(empty.to_string(), "0");
    }
}