acir/native_types/expression/
mod.rs1use crate::{circuit::PublicInputs, native_types::Witness};
2use acir_field::AcirField;
3use serde::{Deserialize, Serialize};
4use std::cmp::Ordering;
5mod operators;
6
/// A polynomial of degree at most two over witnesses: the sum of every
/// multiplication term, every linear term and the constant `q_c`.
///
/// Operations such as `add_mul` merge terms assuming they are sorted by witness
/// (see `sort`). Equality is structural (derived `PartialEq`), so the same polynomial
/// with its terms in a different order, or with extra zero-coefficient terms, does
/// not compare equal.
#[derive(Clone, Debug, PartialEq, Eq, Serialize, Deserialize, Hash)]
#[cfg_attr(feature = "arb", derive(proptest_derive::Arbitrary))]
pub struct Expression<F> {
    /// Degree-two terms stored as `(coefficient, lhs, rhs)`, each representing
    /// `coefficient * lhs * rhs`.
    pub mul_terms: Vec<(F, Witness, Witness)>,

    /// Degree-one terms stored as `(coefficient, witness)`, each representing
    /// `coefficient * witness`.
    pub linear_combinations: Vec<(F, Witness)>,
    /// The constant term of the polynomial.
    pub q_c: F,
}
44
45impl<F: AcirField> Default for Expression<F> {
46 fn default() -> Self {
47 Expression { mul_terms: Vec::new(), linear_combinations: Vec::new(), q_c: F::zero() }
48 }
49}
50
51impl<F: AcirField> std::fmt::Display for Expression<F> {
52 fn fmt(&self, f: &mut std::fmt::Formatter) -> std::fmt::Result {
53 display_expression(self, false, None, f)
54 }
55}
56
57impl<F> Expression<F> {
58 pub fn num_mul_terms(&self) -> usize {
60 self.mul_terms.len()
61 }
62
63 pub fn push_addition_term(&mut self, coefficient: F, variable: Witness) {
65 self.linear_combinations.push((coefficient, variable));
66 }
67
68 pub fn push_multiplication_term(&mut self, coefficient: F, lhs: Witness, rhs: Witness) {
70 self.mul_terms.push((coefficient, lhs, rhs));
71 }
72
73 pub fn is_const(&self) -> bool {
80 self.mul_terms.is_empty() && self.linear_combinations.is_empty()
81 }
82
83 pub fn to_const(&self) -> Option<&F> {
93 self.is_const().then_some(&self.q_c)
94 }
95
96 pub fn is_linear(&self) -> bool {
108 self.mul_terms.is_empty()
109 }
110
111 pub fn is_degree_one_univariate(&self) -> bool {
131 self.is_linear() && self.linear_combinations.len() == 1
132 }
133
134 pub fn sort(&mut self) {
137 self.mul_terms.sort_by(|a, b| a.1.cmp(&b.1).then(a.2.cmp(&b.2)));
138 self.linear_combinations.sort_by(|a, b| a.1.cmp(&b.1));
139 }
140
141 #[cfg(test)]
142 pub(crate) fn is_sorted(&self) -> bool {
143 self.mul_terms.iter().is_sorted_by(|a, b| a.1.cmp(&b.1).then(a.2.cmp(&b.2)).is_le())
144 && self.linear_combinations.iter().is_sorted_by(|a, b| a.1.cmp(&b.1).is_le())
145 }
146}
147
148impl<F: AcirField> Expression<F> {
149 pub fn from_field(q_c: F) -> Self {
150 Self { q_c, ..Default::default() }
151 }
152
153 pub fn zero() -> Self {
154 Self::default()
155 }
156
157 pub fn is_zero(&self) -> bool {
158 *self == Self::zero()
159 }
160
161 pub fn one() -> Self {
162 Self::from_field(F::one())
163 }
164
165 pub fn is_one(&self) -> bool {
166 *self == Self::one()
167 }
168
169 pub fn to_witness(&self) -> Option<Witness> {
176 if self.is_degree_one_univariate() {
177 let (coefficient, variable) = self.linear_combinations[0];
182 let constant = self.q_c;
183
184 if coefficient.is_one() && constant.is_zero() {
185 return Some(variable);
186 }
187 }
188 None
189 }
190
191 pub fn add_mul(&self, k: F, b: &Self) -> Self {
193 if k.is_zero() {
194 return self.clone();
195 } else if self.is_const() {
196 let kb = b * k;
197 return kb + self.q_c;
198 } else if b.is_const() {
199 return self.clone() + (k * b.q_c);
200 }
201
202 let mut mul_terms: Vec<(F, Witness, Witness)> =
203 Vec::with_capacity(self.mul_terms.len() + b.mul_terms.len());
204 let mut linear_combinations: Vec<(F, Witness)> =
205 Vec::with_capacity(self.linear_combinations.len() + b.linear_combinations.len());
206 let q_c = self.q_c + k * b.q_c;
207
208 let mut i1 = 0; let mut i2 = 0; while i1 < self.linear_combinations.len() && i2 < b.linear_combinations.len() {
212 let (a_c, a_w) = self.linear_combinations[i1];
213 let (b_c, b_w) = b.linear_combinations[i2];
214
215 let (coeff, witness) = match a_w.cmp(&b_w) {
216 Ordering::Greater => {
217 i2 += 1;
218 (k * b_c, b_w)
219 }
220 Ordering::Less => {
221 i1 += 1;
222 (a_c, a_w)
223 }
224 Ordering::Equal => {
225 i1 += 1;
228 i2 += 1;
229 (a_c + k * b_c, a_w)
230 }
231 };
232
233 if !coeff.is_zero() {
234 linear_combinations.push((coeff, witness));
235 }
236 }
237
238 while i1 < self.linear_combinations.len() {
240 linear_combinations.push(self.linear_combinations[i1]);
241 i1 += 1;
242 }
243 while i2 < b.linear_combinations.len() {
244 let (b_c, b_w) = b.linear_combinations[i2];
245 let coeff = b_c * k;
246 if !coeff.is_zero() {
247 linear_combinations.push((coeff, b_w));
248 }
249 i2 += 1;
250 }
251
252 i1 = 0; i2 = 0; while i1 < self.mul_terms.len() && i2 < b.mul_terms.len() {
257 let (a_c, a_wl, a_wr) = self.mul_terms[i1];
258 let (b_c, b_wl, b_wr) = b.mul_terms[i2];
259
260 let (coeff, wl, wr) = match (a_wl, a_wr).cmp(&(b_wl, b_wr)) {
261 Ordering::Greater => {
262 i2 += 1;
263 (k * b_c, b_wl, b_wr)
264 }
265 Ordering::Less => {
266 i1 += 1;
267 (a_c, a_wl, a_wr)
268 }
269 Ordering::Equal => {
270 i2 += 1;
273 i1 += 1;
274 (a_c + k * b_c, a_wl, a_wr)
275 }
276 };
277
278 if !coeff.is_zero() {
279 mul_terms.push((coeff, wl, wr));
280 }
281 }
282
283 while i1 < self.mul_terms.len() {
285 mul_terms.push(self.mul_terms[i1]);
286 i1 += 1;
287 }
288 while i2 < b.mul_terms.len() {
289 let (b_c, b_wl, b_wr) = b.mul_terms[i2];
290 let coeff = b_c * k;
291 if coeff != F::zero() {
292 mul_terms.push((coeff, b_wl, b_wr));
293 }
294 i2 += 1;
295 }
296
297 Expression { mul_terms, linear_combinations, q_c }
298 }
299
300 pub fn width(&self) -> usize {
303 let mut width = 0;
304
305 for mul_term in &self.mul_terms {
306 assert_ne!(mul_term.0, F::zero());
308
309 let mut found_x = false;
310 let mut found_y = false;
311
312 for term in &self.linear_combinations {
313 let witness = &term.1;
314 let x = &mul_term.1;
315 let y = &mul_term.2;
316 if witness == x {
317 found_x = true;
318 }
319 if witness == y {
320 found_y = true;
321 }
322 if found_x & found_y {
323 break;
324 }
325 }
326
327 let multiplication_is_squaring = mul_term.1 == mul_term.2;
330
331 let mul_term_width_contribution = if !multiplication_is_squaring && (found_x & found_y)
332 {
333 0
337 } else if found_x || found_y {
338 1
341 } else {
342 2
344 };
345
346 width += mul_term_width_contribution;
347 }
348
349 width += self.linear_combinations.len();
350
351 width
352 }
353}
354
355impl<F: AcirField> From<F> for Expression<F> {
356 fn from(constant: F) -> Self {
357 Expression { q_c: constant, linear_combinations: Vec::new(), mul_terms: Vec::new() }
358 }
359}
360
361impl<F: AcirField> From<Witness> for Expression<F> {
362 fn from(wit: Witness) -> Self {
368 Expression {
369 q_c: F::zero(),
370 linear_combinations: vec![(F::one(), wit)],
371 mul_terms: Vec::new(),
372 }
373 }
374}
375
376pub(crate) fn display_expression<F: AcirField>(
382 expr: &Expression<F>,
383 as_equal_to_zero: bool,
384 return_values: Option<&PublicInputs>,
385 f: &mut std::fmt::Formatter<'_>,
386) -> std::fmt::Result {
387 let mut assignment_witness: Option<usize> = None;
390
391 let mut negate_coefficients = false;
395
396 let linear_witness_one = if as_equal_to_zero {
399 let linear_witness_one = return_values.and_then(|return_values| {
401 expr.linear_combinations.iter().enumerate().find(|(_, (coefficient, witness))| {
402 (coefficient.is_one() || (-*coefficient).is_one())
403 && return_values.0.contains(witness)
404 })
405 });
406 linear_witness_one.or_else(|| {
407 expr.linear_combinations
409 .iter()
410 .enumerate()
411 .filter(|(_, (coefficient, _))| coefficient.is_one() || (-*coefficient).is_one())
412 .max_by_key(|(_, (_, witness))| witness)
413 })
414 } else {
415 None
416 };
417
418 if let Some((index, (coefficient, witness))) = linear_witness_one {
421 assignment_witness = Some(index);
422 negate_coefficients = coefficient.is_one();
423 write!(f, "{witness} = ")?;
424 } else if as_equal_to_zero {
425 write!(f, "0 = ")?;
426 }
427
428 let mut printed_term = false;
429
430 for (coefficient, witness1, witness2) in &expr.mul_terms {
431 let witnesses = [*witness1, *witness2];
432 display_term(*coefficient, witnesses, printed_term, negate_coefficients, f)?;
433 printed_term = true;
434 }
435
436 for (index, (coefficient, witness)) in expr.linear_combinations.iter().enumerate() {
437 if assignment_witness
438 .is_some_and(|show_as_assignment_index| show_as_assignment_index == index)
439 {
440 continue;
442 }
443
444 let witnesses = [*witness];
445 display_term(*coefficient, witnesses, printed_term, negate_coefficients, f)?;
446 printed_term = true;
447 }
448
449 if expr.q_c.is_zero() {
450 if !printed_term {
451 write!(f, "0")?;
452 }
453 } else {
454 let coefficient = expr.q_c;
455 let coefficient = if negate_coefficients { -coefficient } else { coefficient };
456 let coefficient_as_string = coefficient.to_string();
457 let coefficient_is_negative = coefficient_as_string.starts_with('-');
458
459 if printed_term {
460 if coefficient_is_negative {
461 write!(f, " - ")?;
462 } else {
463 write!(f, " + ")?;
464 }
465 }
466
467 let coefficient =
468 if printed_term && coefficient_is_negative { -coefficient } else { coefficient };
469 write!(f, "{coefficient}")?;
470 }
471
472 Ok(())
473}
474
475fn display_term<F: AcirField, const N: usize>(
476 coefficient: F,
477 witnesses: [Witness; N],
478 printed_term: bool,
479 negate_coefficients: bool,
480 f: &mut std::fmt::Formatter<'_>,
481) -> std::fmt::Result {
482 let coefficient = if negate_coefficients { -coefficient } else { coefficient };
483 let coefficient_as_string = coefficient.to_string();
484 let coefficient_is_negative = coefficient_as_string.starts_with('-');
485
486 if printed_term {
487 if coefficient_is_negative {
488 write!(f, " - ")?;
489 } else {
490 write!(f, " + ")?;
491 }
492 }
493
494 let coefficient =
495 if printed_term && coefficient_is_negative { -coefficient } else { coefficient };
496
497 if coefficient.is_one() {
498 } else if (-coefficient).is_one() {
500 write!(f, "-")?;
501 } else {
502 write!(f, "{coefficient}*")?;
503 }
504
505 for (index, witness) in witnesses.iter().enumerate() {
506 if index != 0 {
507 write!(f, "*")?;
508 }
509 write!(f, "{witness}")?;
510 }
511
512 Ok(())
513}
514
#[cfg(test)]
mod tests {
    use super::*;
    use acir_field::FieldElement;

    /// Parses `lhs` and `rhs` and returns `lhs + k * rhs`.
    fn parse_and_add_mul(lhs: &str, k: FieldElement, rhs: &str) -> Expression<FieldElement> {
        let lhs: Expression<FieldElement> = Expression::from_str(lhs).unwrap();
        let rhs: Expression<FieldElement> = Expression::from_str(rhs).unwrap();
        lhs.add_mul(k, &rhs)
    }

    #[test]
    fn add_mul_smoke_test() {
        let result = parse_and_add_mul(
            "2*w1*w2",
            FieldElement::from(10u128),
            "3*w0*w2 + 3*w1*w2 + 4*w4*w5 + 4*w4 + 1",
        );
        assert_eq!(result.to_string(), "30*w0*w2 + 32*w1*w2 + 40*w4*w5 + 40*w4 + 10");
    }

    #[test]
    fn add_mul_with_zero_coefficient() {
        let a: Expression<FieldElement> = Expression::from_str("2*w1*w2 + 3*w1 + 5").unwrap();
        let b: Expression<FieldElement> = Expression::from_str("4*w2*w3 + 6*w2 + 7").unwrap();
        // Adding `0 * b` must leave `a` untouched.
        assert_eq!(a.add_mul(FieldElement::zero(), &b), a);
    }

    #[test]
    fn add_mul_when_self_is_const() {
        let a = Expression::from_field(FieldElement::from(5u128));
        let b: Expression<FieldElement> = Expression::from_str("2*w1*w2 + 3*w1 + 4").unwrap();
        let result = a.add_mul(FieldElement::from(2u128), &b);
        assert_eq!(result.to_string(), "4*w1*w2 + 6*w1 + 13");
    }

    #[test]
    fn add_mul_when_b_is_const() {
        let a: Expression<FieldElement> = Expression::from_str("2*w1*w2 + 3*w1 + 4").unwrap();
        let b = Expression::from_field(FieldElement::from(5u128));
        let result = a.add_mul(FieldElement::from(3u128), &b);
        assert_eq!(result.to_string(), "2*w1*w2 + 3*w1 + 19");
    }

    #[test]
    fn add_mul_merges_linear_terms() {
        let result = parse_and_add_mul("5*w1 + 3*w2", FieldElement::from(2u128), "2*w1 + 4*w3");
        assert_eq!(result.to_string(), "9*w1 + 3*w2 + 8*w3");
    }

    #[test]
    fn add_mul_merges_mul_terms() {
        let result = parse_and_add_mul(
            "5*w1*w2 + 3*w3*w4",
            FieldElement::from(3u128),
            "2*w1*w2 + 4*w5*w6",
        );
        assert_eq!(result.to_string(), "11*w1*w2 + 3*w3*w4 + 12*w5*w6");
    }

    #[test]
    fn add_mul_cancels_terms_to_zero() {
        let result = parse_and_add_mul("6*w1 + 3*w1*w2", FieldElement::from(-2i128), "3*w1 + w1*w2");
        assert_eq!(result.to_string(), "w1*w2");
    }

    #[test]
    fn add_mul_maintains_sorted_order() {
        let result = parse_and_add_mul("w5 + w1*w3", FieldElement::one(), "w2 + w0*w1");
        assert!(result.is_sorted());
        assert_eq!(result.to_string(), "w0*w1 + w1*w3 + w2 + w5");
    }

    #[test]
    fn add_mul_with_constant_terms() {
        let result = parse_and_add_mul("2*w1 + 10", FieldElement::from(4u128), "3*w2 + 5");
        assert_eq!(result.to_string(), "2*w1 + 12*w2 + 30");
    }

    #[test]
    fn add_mul_complex_expression() {
        let result = parse_and_add_mul(
            "2*w1*w2 + 3*w3*w4 + 5*w1 + 7*w3 + 11",
            FieldElement::from(2u128),
            "w1*w2 + 4*w5*w6 + 2*w1 + 6*w5 + 13",
        );
        assert_eq!(result.to_string(), "4*w1*w2 + 3*w3*w4 + 8*w5*w6 + 9*w1 + 7*w3 + 12*w5 + 37");
    }

    #[test]
    fn display_zero() {
        assert_eq!(Expression::<FieldElement>::default().to_string(), "0");
    }
}