acir/native_types/expression/
mod.rs1use crate::{circuit::PublicInputs, native_types::Witness};
2use acir_field::AcirField;
3use msgpack_tagged::MsgpackTagged;
4use serde::{Deserialize, Serialize};
5use std::cmp::Ordering;
6mod operators;
7
/// A polynomial of degree at most two over witnesses:
///
/// `sum(c * lhs * rhs for (c, lhs, rhs) in mul_terms)
///  + sum(c * w for (c, w) in linear_combinations)
///  + q_c`
///
/// Several helpers (e.g. `add_mul`) merge term lists pairwise and therefore expect them to be
/// in the canonical order produced by `Expression::sort`.
#[derive(Clone, Debug, PartialEq, Eq, Hash)]
#[derive(Serialize, Deserialize, MsgpackTagged)]
#[cfg_attr(feature = "arb", derive(proptest_derive::Arbitrary))]
pub struct Expression<F> {
    // NOTE(review): `#[tag(n)]` feeds the `MsgpackTagged` derive; presumably these numbers must
    // stay stable for serialized compatibility — confirm before renumbering.
    #[tag(0)]
    /// Quadratic terms: each `(c, lhs, rhs)` contributes `c * lhs * rhs`.
    ///
    /// After `Expression::sort`, every term has `lhs <= rhs` and the list is ordered by
    /// `(lhs, rhs)`.
    pub mul_terms: Vec<(F, Witness, Witness)>,

    #[tag(1)]
    /// Linear terms: each `(c, w)` contributes `c * w`.
    ///
    /// After `Expression::sort`, the list is ordered by witness.
    pub linear_combinations: Vec<(F, Witness)>,

    #[tag(2)]
    /// The constant term.
    pub q_c: F,
}
50
51impl<F: AcirField> Default for Expression<F> {
52 fn default() -> Self {
53 Expression { mul_terms: Vec::new(), linear_combinations: Vec::new(), q_c: F::zero() }
54 }
55}
56
57impl<F: AcirField> std::fmt::Display for Expression<F> {
58 fn fmt(&self, f: &mut std::fmt::Formatter) -> std::fmt::Result {
59 display_expression(self, false, None, f)
60 }
61}
62
63impl<F> Expression<F> {
64 pub fn num_mul_terms(&self) -> usize {
66 self.mul_terms.len()
67 }
68
69 pub fn push_addition_term(&mut self, coefficient: F, variable: Witness) {
71 self.linear_combinations.push((coefficient, variable));
72 }
73
74 pub fn push_multiplication_term(&mut self, coefficient: F, lhs: Witness, rhs: Witness) {
76 self.mul_terms.push((coefficient, lhs, rhs));
77 }
78
79 pub fn is_const(&self) -> bool {
86 self.mul_terms.is_empty() && self.linear_combinations.is_empty()
87 }
88
89 pub fn to_const(&self) -> Option<&F> {
99 self.is_const().then_some(&self.q_c)
100 }
101
102 pub fn is_linear(&self) -> bool {
114 self.mul_terms.is_empty()
115 }
116
117 pub fn is_degree_one_univariate(&self) -> bool {
137 self.is_linear() && self.linear_combinations.len() == 1
138 }
139
140 pub fn sort(&mut self) {
144 for term in &mut self.mul_terms {
145 if term.1 > term.2 {
146 std::mem::swap(&mut term.1, &mut term.2);
147 }
148 }
149 self.mul_terms.sort_by(|a, b| a.1.cmp(&b.1).then(a.2.cmp(&b.2)));
150 self.linear_combinations.sort_by_key(|a| a.1);
151 }
152
153 #[cfg(test)]
154 pub(crate) fn is_sorted(&self) -> bool {
155 self.mul_terms.iter().all(|term| term.1 <= term.2)
156 && self.mul_terms.iter().is_sorted_by(|a, b| a.1.cmp(&b.1).then(a.2.cmp(&b.2)).is_le())
157 && self.linear_combinations.iter().is_sorted_by(|a, b| a.1.cmp(&b.1).is_le())
158 }
159}
160
161impl<F: AcirField> Expression<F> {
162 pub fn from_field(q_c: F) -> Self {
163 Self { q_c, ..Default::default() }
164 }
165
166 pub fn zero() -> Self {
167 Self::default()
168 }
169
170 pub fn is_zero(&self) -> bool {
171 *self == Self::zero()
172 }
173
174 pub fn one() -> Self {
175 Self::from_field(F::one())
176 }
177
178 pub fn is_one(&self) -> bool {
179 *self == Self::one()
180 }
181
182 pub fn to_witness(&self) -> Option<Witness> {
189 if self.is_degree_one_univariate() {
190 let (coefficient, variable) = self.linear_combinations[0];
195 let constant = self.q_c;
196
197 if coefficient.is_one() && constant.is_zero() {
198 return Some(variable);
199 }
200 }
201 None
202 }
203
204 pub fn add_mul(&self, k: F, b: &Self) -> Self {
206 if k.is_zero() {
207 return self.clone();
208 } else if self.is_const() {
209 let kb = b * k;
210 return kb + self.q_c;
211 } else if b.is_const() {
212 return self.clone() + (k * b.q_c);
213 }
214
215 let mut mul_terms: Vec<(F, Witness, Witness)> =
216 Vec::with_capacity(self.mul_terms.len() + b.mul_terms.len());
217 let mut linear_combinations: Vec<(F, Witness)> =
218 Vec::with_capacity(self.linear_combinations.len() + b.linear_combinations.len());
219 let q_c = self.q_c + k * b.q_c;
220
221 let mut i1 = 0; let mut i2 = 0; while i1 < self.linear_combinations.len() && i2 < b.linear_combinations.len() {
225 let (a_c, a_w) = self.linear_combinations[i1];
226 let (b_c, b_w) = b.linear_combinations[i2];
227
228 let (coeff, witness) = match a_w.cmp(&b_w) {
229 Ordering::Greater => {
230 i2 += 1;
231 (k * b_c, b_w)
232 }
233 Ordering::Less => {
234 i1 += 1;
235 (a_c, a_w)
236 }
237 Ordering::Equal => {
238 i1 += 1;
241 i2 += 1;
242 (a_c + k * b_c, a_w)
243 }
244 };
245
246 if !coeff.is_zero() {
247 linear_combinations.push((coeff, witness));
248 }
249 }
250
251 while i1 < self.linear_combinations.len() {
253 linear_combinations.push(self.linear_combinations[i1]);
254 i1 += 1;
255 }
256 while i2 < b.linear_combinations.len() {
257 let (b_c, b_w) = b.linear_combinations[i2];
258 let coeff = b_c * k;
259 if !coeff.is_zero() {
260 linear_combinations.push((coeff, b_w));
261 }
262 i2 += 1;
263 }
264
265 i1 = 0; i2 = 0; while i1 < self.mul_terms.len() && i2 < b.mul_terms.len() {
270 let (a_c, a_wl, a_wr) = self.mul_terms[i1];
271 let (b_c, b_wl, b_wr) = b.mul_terms[i2];
272
273 let (coeff, wl, wr) = match (a_wl, a_wr).cmp(&(b_wl, b_wr)) {
274 Ordering::Greater => {
275 i2 += 1;
276 (k * b_c, b_wl, b_wr)
277 }
278 Ordering::Less => {
279 i1 += 1;
280 (a_c, a_wl, a_wr)
281 }
282 Ordering::Equal => {
283 i2 += 1;
286 i1 += 1;
287 (a_c + k * b_c, a_wl, a_wr)
288 }
289 };
290
291 if !coeff.is_zero() {
292 mul_terms.push((coeff, wl, wr));
293 }
294 }
295
296 while i1 < self.mul_terms.len() {
298 mul_terms.push(self.mul_terms[i1]);
299 i1 += 1;
300 }
301 while i2 < b.mul_terms.len() {
302 let (b_c, b_wl, b_wr) = b.mul_terms[i2];
303 let coeff = b_c * k;
304 if coeff != F::zero() {
305 mul_terms.push((coeff, b_wl, b_wr));
306 }
307 i2 += 1;
308 }
309
310 Expression { mul_terms, linear_combinations, q_c }
311 }
312
313 pub fn width(&self) -> usize {
320 if self.mul_terms.len() > 1 {
321 unimplemented!("ICE - width() does not support expressions with multiple mul terms");
322 }
323
324 let mut width = 0;
325
326 for mul_term in &self.mul_terms {
327 assert_ne!(mul_term.0, F::zero());
329
330 let mut found_x = false;
331 let mut found_y = false;
332
333 for term in &self.linear_combinations {
334 let witness = &term.1;
335 let x = &mul_term.1;
336 let y = &mul_term.2;
337 if witness == x {
338 found_x = true;
339 }
340 if witness == y {
341 found_y = true;
342 }
343 if found_x & found_y {
344 break;
345 }
346 }
347
348 let multiplication_is_squaring = mul_term.1 == mul_term.2;
351
352 let mul_term_width_contribution = if !multiplication_is_squaring && (found_x & found_y)
353 {
354 0
358 } else if found_x || found_y {
359 1
362 } else {
363 2
365 };
366
367 width += mul_term_width_contribution;
368 }
369
370 width += self.linear_combinations.len();
371
372 width
373 }
374}
375
376impl<F: AcirField> From<F> for Expression<F> {
377 fn from(constant: F) -> Self {
378 Expression { q_c: constant, linear_combinations: Vec::new(), mul_terms: Vec::new() }
379 }
380}
381
382impl<F: AcirField> From<Witness> for Expression<F> {
383 fn from(wit: Witness) -> Self {
389 Expression {
390 q_c: F::zero(),
391 linear_combinations: vec![(F::one(), wit)],
392 mul_terms: Vec::new(),
393 }
394 }
395}
396
397pub(crate) fn display_expression<F: AcirField>(
403 expr: &Expression<F>,
404 as_equal_to_zero: bool,
405 return_values: Option<&PublicInputs>,
406 f: &mut std::fmt::Formatter<'_>,
407) -> std::fmt::Result {
408 let mut assignment_witness: Option<usize> = None;
411
412 let mut negate_coefficients = false;
416
417 let linear_witness_one = if as_equal_to_zero {
420 let linear_witness_one = return_values.and_then(|return_values| {
422 expr.linear_combinations.iter().enumerate().find(|(_, (coefficient, witness))| {
423 (coefficient.is_one() || (-*coefficient).is_one())
424 && return_values.0.contains(witness)
425 })
426 });
427 linear_witness_one.or_else(|| {
428 expr.linear_combinations
430 .iter()
431 .enumerate()
432 .filter(|(_, (coefficient, _))| coefficient.is_one() || (-*coefficient).is_one())
433 .max_by_key(|(_, (_, witness))| witness)
434 })
435 } else {
436 None
437 };
438
439 if let Some((index, (coefficient, witness))) = linear_witness_one {
442 assignment_witness = Some(index);
443 negate_coefficients = coefficient.is_one();
444 write!(f, "{witness} = ")?;
445 } else if as_equal_to_zero {
446 write!(f, "0 = ")?;
447 }
448
449 let mut printed_term = false;
450
451 for (coefficient, witness1, witness2) in &expr.mul_terms {
452 let witnesses = [*witness1, *witness2];
453 display_term(*coefficient, witnesses, printed_term, negate_coefficients, f)?;
454 printed_term = true;
455 }
456
457 for (index, (coefficient, witness)) in expr.linear_combinations.iter().enumerate() {
458 if assignment_witness
459 .is_some_and(|show_as_assignment_index| show_as_assignment_index == index)
460 {
461 continue;
463 }
464
465 let witnesses = [*witness];
466 display_term(*coefficient, witnesses, printed_term, negate_coefficients, f)?;
467 printed_term = true;
468 }
469
470 if expr.q_c.is_zero() {
471 if !printed_term {
472 write!(f, "0")?;
473 }
474 } else {
475 let coefficient = expr.q_c;
476 let coefficient = if negate_coefficients { -coefficient } else { coefficient };
477 let coefficient_as_string = coefficient.to_string();
478 let coefficient_is_negative = coefficient_as_string.starts_with('-');
479
480 if printed_term {
481 if coefficient_is_negative {
482 write!(f, " - ")?;
483 } else {
484 write!(f, " + ")?;
485 }
486 }
487
488 let coefficient =
489 if printed_term && coefficient_is_negative { -coefficient } else { coefficient };
490 write!(f, "{coefficient}")?;
491 }
492
493 Ok(())
494}
495
496fn display_term<F: AcirField, const N: usize>(
497 coefficient: F,
498 witnesses: [Witness; N],
499 printed_term: bool,
500 negate_coefficients: bool,
501 f: &mut std::fmt::Formatter<'_>,
502) -> std::fmt::Result {
503 let coefficient = if negate_coefficients { -coefficient } else { coefficient };
504 let coefficient_as_string = coefficient.to_string();
505 let coefficient_is_negative = coefficient_as_string.starts_with('-');
506
507 if printed_term {
508 if coefficient_is_negative {
509 write!(f, " - ")?;
510 } else {
511 write!(f, " + ")?;
512 }
513 }
514
515 let coefficient =
516 if printed_term && coefficient_is_negative { -coefficient } else { coefficient };
517
518 if coefficient.is_one() {
519 } else if (-coefficient).is_one() {
521 write!(f, "-")?;
522 } else {
523 write!(f, "{coefficient}*")?;
524 }
525
526 for (index, witness) in witnesses.iter().enumerate() {
527 if index != 0 {
528 write!(f, "*")?;
529 }
530 write!(f, "{witness}")?;
531 }
532
533 Ok(())
534}
535
#[cfg(test)]
mod tests {
    use super::*;
    use acir_field::FieldElement;

    /// Parses a textual expression fixture; every fixture in this module is well-formed.
    fn parse(source: &str) -> Expression<FieldElement> {
        Expression::from_str(source).unwrap()
    }

    /// Builds a field element from a small non-negative integer.
    fn field(value: u128) -> FieldElement {
        FieldElement::from(value)
    }

    #[test]
    fn add_mul_smoke_test() {
        // 2*w1*w2 + 10 * (3*w0*w2 + 3*w1*w2 + 4*w4*w5 + 4*w4 + 1)
        let lhs = parse("2*w1*w2");
        let rhs = parse("3*w0*w2 + 3*w1*w2 + 4*w4*w5 + 4*w4 + 1");

        let sum = lhs.add_mul(field(10), &rhs);
        assert_eq!(sum.to_string(), "30*w0*w2 + 32*w1*w2 + 40*w4*w5 + 40*w4 + 10");
    }

    #[test]
    fn add_mul_with_zero_coefficient() {
        // Scaling by zero leaves `self` untouched.
        let lhs = parse("2*w1*w2 + 3*w1 + 5");
        let rhs = parse("4*w2*w3 + 6*w2 + 7");

        let sum = lhs.add_mul(FieldElement::zero(), &rhs);
        assert_eq!(sum, lhs);
    }

    #[test]
    fn add_mul_when_self_is_const() {
        // 5 + 2 * (2*w1*w2 + 3*w1 + 4)
        let lhs = Expression::from_field(field(5));
        let rhs = parse("2*w1*w2 + 3*w1 + 4");

        let sum = lhs.add_mul(field(2), &rhs);
        assert_eq!(sum.to_string(), "4*w1*w2 + 6*w1 + 13");
    }

    #[test]
    fn add_mul_when_b_is_const() {
        // (2*w1*w2 + 3*w1 + 4) + 3 * 5
        let lhs = parse("2*w1*w2 + 3*w1 + 4");
        let rhs = Expression::from_field(field(5));

        let sum = lhs.add_mul(field(3), &rhs);
        assert_eq!(sum.to_string(), "2*w1*w2 + 3*w1 + 19");
    }

    #[test]
    fn add_mul_merges_linear_terms() {
        let lhs = parse("5*w1 + 3*w2");
        let rhs = parse("2*w1 + 4*w3");

        let sum = lhs.add_mul(field(2), &rhs);
        assert_eq!(sum.to_string(), "9*w1 + 3*w2 + 8*w3");
    }

    #[test]
    fn add_mul_merges_mul_terms() {
        let lhs = parse("5*w1*w2 + 3*w3*w4");
        let rhs = parse("2*w1*w2 + 4*w5*w6");

        let sum = lhs.add_mul(field(3), &rhs);
        assert_eq!(sum.to_string(), "11*w1*w2 + 3*w3*w4 + 12*w5*w6");
    }

    #[test]
    fn add_mul_cancels_terms_to_zero() {
        // The linear w1 terms cancel (6 - 2*3 = 0) and must disappear from the result.
        let lhs = parse("6*w1 + 3*w1*w2");
        let rhs = parse("3*w1 + w1*w2");

        let sum = lhs.add_mul(FieldElement::from(-2i128), &rhs);
        assert_eq!(sum.to_string(), "w1*w2");
    }

    #[test]
    fn add_mul_maintains_sorted_order() {
        let lhs = parse("w5 + w1*w3");
        let rhs = parse("w2 + w0*w1");

        let sum = lhs.add_mul(FieldElement::one(), &rhs);
        assert!(sum.is_sorted());
        assert_eq!(sum.to_string(), "w0*w1 + w1*w3 + w2 + w5");
    }

    #[test]
    fn add_mul_with_constant_terms() {
        let lhs = parse("2*w1 + 10");
        let rhs = parse("3*w2 + 5");

        let sum = lhs.add_mul(field(4), &rhs);
        assert_eq!(sum.to_string(), "2*w1 + 12*w2 + 30");
    }

    #[test]
    fn add_mul_complex_expression() {
        let lhs = parse("2*w1*w2 + 3*w3*w4 + 5*w1 + 7*w3 + 11");
        let rhs = parse("w1*w2 + 4*w5*w6 + 2*w1 + 6*w5 + 13");

        let sum = lhs.add_mul(field(2), &rhs);
        assert_eq!(sum.to_string(), "4*w1*w2 + 3*w3*w4 + 8*w5*w6 + 9*w1 + 7*w3 + 12*w5 + 37");
    }

    #[test]
    fn display_zero() {
        let empty = Expression::<FieldElement>::default();
        assert_eq!(empty.to_string(), "0");
    }
}