// acir/native_types/expression/operators.rs

1use crate::native_types::Witness;
2use acir_field::AcirField;
3use std::ops::{Add, Mul, Neg, Sub};
4
5use super::Expression;

// Negation

9impl<F: AcirField> Neg for &Expression<F> {
10    type Output = Expression<F>;
11    fn neg(self) -> Self::Output {
12        let mut mul_terms = self.mul_terms.clone();
13        for (q_m, _, _) in &mut mul_terms {
14            *q_m = -*q_m;
15        }
16
17        let mut linear_combinations = self.linear_combinations.clone();
18        for (q_k, _) in &mut linear_combinations {
19            *q_k = -*q_k;
20        }
21
22        Expression { mul_terms, linear_combinations, q_c: -self.q_c }
23    }
24}
25
26impl<F: AcirField> Neg for Expression<F> {
27    type Output = Expression<F>;
28    fn neg(mut self) -> Self::Output {
29        for (q_m, _, _) in &mut self.mul_terms {
30            *q_m = -*q_m;
31        }
32
33        for (q_k, _) in &mut self.linear_combinations {
34            *q_k = -*q_k;
35        }
36
37        self.q_c = -self.q_c;
38
39        self
40    }
41}

// FieldElement

45impl<F: AcirField> Add<F> for Expression<F> {
46    type Output = Self;
47    fn add(self, rhs: F) -> Self::Output {
48        // Increase the constant
49        let q_c = self.q_c + rhs;
50
51        Expression { mul_terms: self.mul_terms, q_c, linear_combinations: self.linear_combinations }
52    }
53}
54
55impl<F: AcirField> Sub<F> for Expression<F> {
56    type Output = Self;
57    fn sub(self, rhs: F) -> Self::Output {
58        // Increase the constant
59        let q_c = self.q_c - rhs;
60
61        Expression { mul_terms: self.mul_terms, q_c, linear_combinations: self.linear_combinations }
62    }
63}
64
65impl<F: AcirField> Mul<F> for &Expression<F> {
66    type Output = Expression<F>;
67    fn mul(self, rhs: F) -> Self::Output {
68        // Scale the mul terms
69        let mul_terms: Vec<_> =
70            self.mul_terms.iter().map(|(q_m, w_l, w_r)| (*q_m * rhs, *w_l, *w_r)).collect();
71
72        // Scale the linear combinations terms
73        let lin_combinations: Vec<_> =
74            self.linear_combinations.iter().map(|(q_l, w_l)| (*q_l * rhs, *w_l)).collect();
75
76        // Scale the constant
77        let q_c = self.q_c * rhs;
78
79        Expression { mul_terms, q_c, linear_combinations: lin_combinations }
80    }
81}

// Witness

85impl<F: AcirField> Add<Witness> for &Expression<F> {
86    type Output = Expression<F>;
87    fn add(self, rhs: Witness) -> Self::Output {
88        self + &Expression::from(rhs)
89    }
90}
91
92impl<F: AcirField> Add<&Expression<F>> for Witness {
93    type Output = Expression<F>;
94    #[inline]
95    fn add(self, rhs: &Expression<F>) -> Self::Output {
96        rhs + self
97    }
98}
99
100impl<F: AcirField> Sub<Witness> for &Expression<F> {
101    type Output = Expression<F>;
102    fn sub(self, rhs: Witness) -> Self::Output {
103        self - &Expression::from(rhs)
104    }
105}
106
107impl<F: AcirField> Sub<&Expression<F>> for Witness {
108    type Output = Expression<F>;
109    #[inline]
110    fn sub(self, rhs: &Expression<F>) -> Self::Output {
111        &Expression::from(self) - rhs
112    }
113}

// `Mul<Witness>` is not implemented: multiplying an expression that already contains degree-2
// terms by a witness would produce degree-3 terms, which an `Expression` cannot represent.

// Expression

119impl<F: AcirField> Add<&Expression<F>> for &Expression<F> {
120    type Output = Expression<F>;
121    fn add(self, rhs: &Expression<F>) -> Self::Output {
122        self.add_mul(F::one(), rhs)
123    }
124}
125
126impl<F: AcirField> Sub<&Expression<F>> for &Expression<F> {
127    type Output = Expression<F>;
128    fn sub(self, rhs: &Expression<F>) -> Self::Output {
129        self.add_mul(-F::one(), rhs)
130    }
131}
132
133impl<F: AcirField> Mul<&Expression<F>> for &Expression<F> {
134    type Output = Option<Expression<F>>;
135    fn mul(self, rhs: &Expression<F>) -> Self::Output {
136        if self.is_const() {
137            return Some(rhs * self.q_c);
138        } else if rhs.is_const() {
139            return Some(self * rhs.q_c);
140        } else if !(self.is_linear() && rhs.is_linear()) {
141            // `Expression`s can only represent terms which are up to degree 2.
142            // We then disallow multiplication of `Expression`s which have degree 2 terms.
143            return None;
144        }
145
146        // Start with the constant term: q_c_self * q_c_rhs
147        let mut output = Expression::from_field(self.q_c * rhs.q_c);
148
149        // 'each linear term in self' * 'each linear term in rhs'
150        // XXX: This has a quadratic cost that can be improved, but for now we favor simplicity.
151        for lc in &self.linear_combinations {
152            let single = single_mul(lc.1, rhs);
153            output = output.add_mul(lc.0, &single);
154        }
155
156        // Add linear terms from self scaled by rhs's constant: self.linear * rhs.q_c
157        if !rhs.q_c.is_zero() {
158            let self_linear = Expression {
159                mul_terms: Vec::new(),
160                linear_combinations: self.linear_combinations.clone(),
161                q_c: F::zero(),
162            };
163            output = output.add_mul(rhs.q_c, &self_linear);
164        }
165
166        // Add linear terms from rhs scaled by self's constant: rhs.linear * self.q_c
167        if !self.q_c.is_zero() {
168            let rhs_linear = Expression {
169                mul_terms: Vec::new(),
170                linear_combinations: rhs.linear_combinations.clone(),
171                q_c: F::zero(),
172            };
173            output = output.add_mul(self.q_c, &rhs_linear);
174        }
175
176        Some(output)
177    }
178}
179
180/// Returns `w*b.linear_combinations`
181fn single_mul<F: AcirField>(w: Witness, b: &Expression<F>) -> Expression<F> {
182    Expression {
183        mul_terms: b
184            .linear_combinations
185            .iter()
186            .map(|(a, wit)| {
187                let (wl, wr) = if w < *wit { (w, *wit) } else { (*wit, w) };
188                (*a, wl, wr)
189            })
190            .collect(),
191        ..Default::default()
192    }
193}
194
#[cfg(test)]
mod tests {
    //! Unit tests for the `Expression` operator overloads: negation, addition and subtraction
    //! (with expressions, scalars and witnesses), and multiplication.

    use crate::native_types::{Expression, Witness};
    use acir_field::{AcirField, FieldElement};

    #[test]
    fn add_smoke_test() {
        let a = Expression::from_str("2*w2 + 2").unwrap();
        let b = Expression::from_str("4*w4 + 1").unwrap();
        let result = Expression::from_str("2*w2 + 4*w4 + 3").unwrap();
        assert_eq!(&a + &b, result);

        // Enforce commutativity
        assert_eq!(&a + &b, &b + &a);
    }

    #[test]
    fn sub_smoke_test() {
        let a = Expression::from_str("2*w2 + 2").unwrap();
        let b = Expression::from_str("4*w4 + 1").unwrap();

        // Subtraction must agree with adding the negation.
        assert_eq!(&a - &b, &a + &(-&b));

        // Subtracting an expression from itself cancels its constant term.
        assert!((&a - &a).q_c.is_zero());
    }

    #[test]
    fn add_and_sub_field_constant() {
        // Adding or subtracting a scalar only moves the constant term.
        let a = Expression::from_str("2*w2 + 2").unwrap();
        let three = FieldElement::from(3u128);

        assert_eq!(a.clone() + three, Expression::from_str("2*w2 + 5").unwrap());
        assert_eq!(a.clone() - three, a.clone() + (-three));
        assert_eq!((a.clone() + three) - three, a);
    }

    #[test]
    fn witness_operators() {
        let a = Expression::from_str("2*w2 + 2").unwrap();
        let w = Witness(3);

        // Adding a witness is commutative and matches adding its expression form.
        assert_eq!(&a + w, w + &a);
        assert_eq!(&a + w, &a + &Expression::from(w));

        // Subtracting a witness matches subtracting its expression form, and swapping the
        // operands negates the result.
        assert_eq!(&a - w, &a - &Expression::from(w));
        assert_eq!(w - &a, -(&a - w));
    }

    #[test]
    fn mul_smoke_test() {
        let a = Expression::from_str("2*w2 + 2").unwrap();
        let b = Expression::from_str("4*w4 + 1").unwrap();
        let result = Expression::from_str("8*w2*w4 + 2*w2 + 8*w4 + 2").unwrap();
        assert_eq!((&a * &b).unwrap(), result);

        // Enforce commutativity
        assert_eq!(&a * &b, &b * &a);
    }

    #[test]
    fn mul_by_zero_constant() {
        // Multiplying by zero should give zero (with zero coefficients)
        // Note: The implementation may leave zero-coefficient terms in place
        let a = Expression::from_str("3*w1 + 5*w2 + 7").unwrap();
        let zero: Expression<FieldElement> = Expression::zero();

        let result = (&a * &zero).unwrap();
        // All terms should have zero coefficients and the constant should be zero
        assert!(result.mul_terms.is_empty());
        assert!(result.q_c.is_zero());
        for (coeff, _) in &result.linear_combinations {
            assert!(coeff.is_zero());
        }

        // Enforce commutativity
        assert_eq!(&a * &zero, &zero * &a);
    }

    #[test]
    fn mul_by_one_constant() {
        // Multiplying by one should give the same expression
        let a = Expression::from_str("3*w1 + 5*w2 + 7").unwrap();
        let one: Expression<FieldElement> = Expression::one();

        let result = (&a * &one).unwrap();
        assert_eq!(result, a);

        // Enforce commutativity
        assert_eq!(&a * &one, &one * &a);
    }

    #[test]
    fn mul_by_scalar_constant() {
        // Multiplying by a constant should scale all terms
        let a = Expression::from_str("2*w1 + 3*w2 + 4").unwrap();
        let scalar = Expression::from_field(FieldElement::from(5u128));

        let result = (&a * &scalar).unwrap();
        assert_eq!(result.to_string(), "10*w1 + 15*w2 + 20");

        // Enforce commutativity
        assert_eq!(&a * &scalar, &scalar * &a);
    }

    #[test]
    fn mul_two_constants() {
        // Multiplying two constants
        let a = Expression::from_field(FieldElement::from(3u128));
        let b = Expression::from_field(FieldElement::from(7u128));

        let result = (&a * &b).unwrap();
        assert_eq!(result, Expression::from_field(FieldElement::from(21u128)));

        // Enforce commutativity
        assert_eq!(&a * &b, &b * &a);
    }

    #[test]
    fn mul_linear_expressions() {
        // Test multiplication of two linear expressions (no constants)
        let a = Expression::from_str("2*w1 + 3*w2").unwrap();
        let b = Expression::from_str("4*w3 + 5*w4").unwrap();

        let result = (&a * &b).unwrap();
        // (2*w1 + 3*w2) * (4*w3 + 5*w4) = 8*w1*w3 + 10*w1*w4 + 12*w2*w3 + 15*w2*w4
        assert_eq!(result.to_string(), "8*w1*w3 + 10*w1*w4 + 12*w2*w3 + 15*w2*w4");

        // Enforce commutativity
        assert_eq!(&a * &b, &b * &a);
    }

    #[test]
    fn mul_with_shared_witness() {
        // Test multiplication where both expressions share a witness
        let a = Expression::from_str("2*w1 + 3*w2").unwrap();
        let b = Expression::from_str("4*w1 + 5*w3").unwrap();

        let result = (&a * &b).unwrap();
        // (2*w1 + 3*w2) * (4*w1 + 5*w3) = 8*w1*w1 + 10*w1*w3 + 12*w1*w2 + 15*w2*w3
        assert_eq!(result.to_string(), "8*w1*w1 + 12*w1*w2 + 10*w1*w3 + 15*w2*w3");

        // Enforce commutativity
        assert_eq!(&a * &b, &b * &a);
    }

    #[test]
    fn mul_single_witness() {
        // Test squaring a single witness: (w1) * (w1) = w1*w1
        let a = Expression::from_str("w1").unwrap();
        let b = Expression::from_str("w1").unwrap();

        let result = (&a * &b).unwrap();
        assert_eq!(result.to_string(), "w1*w1");
    }

    #[test]
    fn mul_with_constant_term() {
        // Test multiplication where one expression has a constant term
        let a = Expression::from_str("2*w1 + 3").unwrap();
        let b = Expression::from_str("4*w2 + 5").unwrap();

        let result = (&a * &b).unwrap();
        // (2*w1 + 3) * (4*w2 + 5) = 8*w1*w2 + 10*w1 + 12*w2 + 15
        assert_eq!(result.to_string(), "8*w1*w2 + 10*w1 + 12*w2 + 15");

        // Enforce commutativity
        assert_eq!(&a * &b, &b * &a);
    }

    #[test]
    fn mul_degree_two_fails() {
        // Multiplying expressions that would result in degree > 2 should return None
        let a = Expression::from_str("2*w1*w2 + 3*w1").unwrap();
        let b = Expression::from_str("4*w3 + 5").unwrap();

        let result = &a * &b;
        assert!(result.is_none(), "Multiplication should fail for degree > 2");

        // Enforce commutativity
        assert_eq!(&a * &b, &b * &a);
    }

    #[test]
    fn mul_both_degree_two_fails() {
        // Multiplying two degree-2 expressions should fail
        let a = Expression::from_str("w1*w2").unwrap();
        let b = Expression::from_str("w3*w4").unwrap();

        let result = &a * &b;
        assert!(result.is_none(), "Multiplication of two degree-2 expressions should fail");

        // Enforce commutativity
        assert_eq!(&a * &b, &b * &a);
    }

    #[test]
    fn mul_complex_linear_expressions() {
        // Test a more complex multiplication
        let a = Expression::from_str("2*w1 + 3*w2 + 4*w3 + 5").unwrap();
        let b = Expression::from_str("6*w4 + 7*w5 + 8").unwrap();

        let result = (&a * &b).unwrap();
        // (2*w1 + 3*w2 + 4*w3 + 5) * (6*w4 + 7*w5 + 8)
        // = 12*w1*w4 + 14*w1*w5 + 18*w2*w4 + 21*w2*w5 + 24*w3*w4 + 28*w3*w5
        //   + 16*w1 + 24*w2 + 32*w3 + 30*w4 + 35*w5 + 40
        assert_eq!(
            result.to_string(),
            "12*w1*w4 + 14*w1*w5 + 18*w2*w4 + 21*w2*w5 + 24*w3*w4 + 28*w3*w5 + 16*w1 + 24*w2 + 32*w3 + 30*w4 + 35*w5 + 40"
        );

        // Enforce commutativity
        assert_eq!(&a * &b, &b * &a);
    }

    #[test]
    fn mul_witness_ordering() {
        // Test that witness pairs are ordered correctly (smaller index first)
        let a = Expression::from_str("w5").unwrap();
        let b = Expression::from_str("w2").unwrap();

        let result = (&a * &b).unwrap();
        // Should be w2*w5, not w5*w2
        assert_eq!(result.to_string(), "w2*w5");

        // Enforce commutativity
        assert_eq!(&a * &b, &b * &a);
    }

    #[test]
    fn mul_result_is_sorted() {
        // Verify the witness ordering in mul_terms is correct
        let a = Expression::from_str("w3 + w1").unwrap();
        let b = Expression::from_str("w4 + w2").unwrap();

        let result = (&a * &b).unwrap();
        // Verify that each mul_term has properly ordered witnesses (smaller first)
        for (_, wl, wr) in &result.mul_terms {
            assert!(wl <= wr, "Witnesses in mul_terms should be ordered");
        }
    }

    #[test]
    fn neg_reference() {
        // Negating a borrowed expression must leave the original untouched
        let a = Expression::from_str("2*w1*w2 + 3*w1 + 5*w2 + 7").unwrap();
        let result = -&a;

        assert_eq!(result.to_string(), "-2*w1*w2 - 3*w1 - 5*w2 - 7");

        // Original should be unchanged
        assert_eq!(a.to_string(), "2*w1*w2 + 3*w1 + 5*w2 + 7");
    }

    #[test]
    fn neg_owned() {
        // Test negation of an owned expression (in-place, no clone)
        let a = Expression::from_str("2*w1*w2 + 3*w1 + 5*w2 + 7").unwrap();
        let result = -a;

        assert_eq!(result.to_string(), "-2*w1*w2 - 3*w1 - 5*w2 - 7");
    }

    #[test]
    fn neg_zero() {
        // Negating zero should give zero
        let zero: Expression<FieldElement> = Expression::zero();
        let result = -&zero;

        assert_eq!(result, Expression::zero());
    }

    #[test]
    fn neg_constant() {
        // Negating a constant expression
        let a = Expression::from_field(FieldElement::from(42u128));
        let result = -a;

        assert_eq!(result.q_c, FieldElement::from(-42i128));
        assert!(result.mul_terms.is_empty());
        assert!(result.linear_combinations.is_empty());
    }

    #[test]
    fn neg_linear_only() {
        // Negating an expression with only linear terms
        let a = Expression::from_str("3*w1 + 5*w2 + 7").unwrap();
        let result = -a;

        assert_eq!(result.to_string(), "-3*w1 - 5*w2 - 7");
    }

    #[test]
    fn neg_mul_only() {
        // Negating an expression with only multiplication terms
        let a = Expression::from_str("2*w1*w2 + 4*w3*w4").unwrap();
        let result = -a;

        assert_eq!(result.to_string(), "-2*w1*w2 - 4*w3*w4");
    }

    #[test]
    fn double_neg() {
        // Double negation should give back the original
        let a = Expression::from_str("2*w1*w2 + 3*w1 + 5").unwrap();
        let result = -(-a.clone());

        assert_eq!(result, a);
    }

    #[test]
    fn neg_preserves_structure() {
        // Negation should preserve the structure (number of terms)
        let a = Expression::from_str("2*w1*w2 + 3*w3*w4 + 5*w1 + 7*w2 + 11").unwrap();
        let result = -&a;

        assert_eq!(result.mul_terms.len(), a.mul_terms.len());
        assert_eq!(result.linear_combinations.len(), a.linear_combinations.len());
    }
}