acir/native_types/expression/
operators.rs1use crate::native_types::Witness;
2use acir_field::AcirField;
3use std::ops::{Add, Mul, Neg, Sub};
4
5use super::Expression;
6
7impl<F: AcirField> Neg for &Expression<F> {
10 type Output = Expression<F>;
11 fn neg(self) -> Self::Output {
12 let mut mul_terms = self.mul_terms.clone();
13 for (q_m, _, _) in &mut mul_terms {
14 *q_m = -*q_m;
15 }
16
17 let mut linear_combinations = self.linear_combinations.clone();
18 for (q_k, _) in &mut linear_combinations {
19 *q_k = -*q_k;
20 }
21
22 Expression { mul_terms, linear_combinations, q_c: -self.q_c }
23 }
24}
25
26impl<F: AcirField> Neg for Expression<F> {
27 type Output = Expression<F>;
28 fn neg(mut self) -> Self::Output {
29 for (q_m, _, _) in &mut self.mul_terms {
30 *q_m = -*q_m;
31 }
32
33 for (q_k, _) in &mut self.linear_combinations {
34 *q_k = -*q_k;
35 }
36
37 self.q_c = -self.q_c;
38
39 self
40 }
41}
42
43impl<F: AcirField> Add<F> for Expression<F> {
46 type Output = Self;
47 fn add(self, rhs: F) -> Self::Output {
48 let q_c = self.q_c + rhs;
50
51 Expression { mul_terms: self.mul_terms, q_c, linear_combinations: self.linear_combinations }
52 }
53}
54
55impl<F: AcirField> Sub<F> for Expression<F> {
56 type Output = Self;
57 fn sub(self, rhs: F) -> Self::Output {
58 let q_c = self.q_c - rhs;
60
61 Expression { mul_terms: self.mul_terms, q_c, linear_combinations: self.linear_combinations }
62 }
63}
64
65impl<F: AcirField> Mul<F> for &Expression<F> {
66 type Output = Expression<F>;
67 fn mul(self, rhs: F) -> Self::Output {
68 let mul_terms: Vec<_> =
70 self.mul_terms.iter().map(|(q_m, w_l, w_r)| (*q_m * rhs, *w_l, *w_r)).collect();
71
72 let lin_combinations: Vec<_> =
74 self.linear_combinations.iter().map(|(q_l, w_l)| (*q_l * rhs, *w_l)).collect();
75
76 let q_c = self.q_c * rhs;
78
79 Expression { mul_terms, q_c, linear_combinations: lin_combinations }
80 }
81}
82
83impl<F: AcirField> Add<Witness> for &Expression<F> {
86 type Output = Expression<F>;
87 fn add(self, rhs: Witness) -> Self::Output {
88 self + &Expression::from(rhs)
89 }
90}
91
92impl<F: AcirField> Add<&Expression<F>> for Witness {
93 type Output = Expression<F>;
94 #[inline]
95 fn add(self, rhs: &Expression<F>) -> Self::Output {
96 rhs + self
97 }
98}
99
100impl<F: AcirField> Sub<Witness> for &Expression<F> {
101 type Output = Expression<F>;
102 fn sub(self, rhs: Witness) -> Self::Output {
103 self - &Expression::from(rhs)
104 }
105}
106
107impl<F: AcirField> Sub<&Expression<F>> for Witness {
108 type Output = Expression<F>;
109 #[inline]
110 fn sub(self, rhs: &Expression<F>) -> Self::Output {
111 &Expression::from(self) - rhs
112 }
113}
114
115impl<F: AcirField> Add<&Expression<F>> for &Expression<F> {
120 type Output = Expression<F>;
121 fn add(self, rhs: &Expression<F>) -> Self::Output {
122 self.add_mul(F::one(), rhs)
123 }
124}
125
126impl<F: AcirField> Sub<&Expression<F>> for &Expression<F> {
127 type Output = Expression<F>;
128 fn sub(self, rhs: &Expression<F>) -> Self::Output {
129 self.add_mul(-F::one(), rhs)
130 }
131}
132
133impl<F: AcirField> Mul<&Expression<F>> for &Expression<F> {
134 type Output = Option<Expression<F>>;
135 fn mul(self, rhs: &Expression<F>) -> Self::Output {
136 if self.is_const() {
137 return Some(rhs * self.q_c);
138 } else if rhs.is_const() {
139 return Some(self * rhs.q_c);
140 } else if !(self.is_linear() && rhs.is_linear()) {
141 return None;
144 }
145
146 let mut output = Expression::from_field(self.q_c * rhs.q_c);
148
149 for lc in &self.linear_combinations {
152 let single = single_mul(lc.1, rhs);
153 output = output.add_mul(lc.0, &single);
154 }
155
156 if !rhs.q_c.is_zero() {
158 let self_linear = Expression {
159 mul_terms: Vec::new(),
160 linear_combinations: self.linear_combinations.clone(),
161 q_c: F::zero(),
162 };
163 output = output.add_mul(rhs.q_c, &self_linear);
164 }
165
166 if !self.q_c.is_zero() {
168 let rhs_linear = Expression {
169 mul_terms: Vec::new(),
170 linear_combinations: rhs.linear_combinations.clone(),
171 q_c: F::zero(),
172 };
173 output = output.add_mul(self.q_c, &rhs_linear);
174 }
175
176 Some(output)
177 }
178}
179
180fn single_mul<F: AcirField>(w: Witness, b: &Expression<F>) -> Expression<F> {
182 Expression {
183 mul_terms: b
184 .linear_combinations
185 .iter()
186 .map(|(a, wit)| {
187 let (wl, wr) = if w < *wit { (w, *wit) } else { (*wit, w) };
188 (*a, wl, wr)
189 })
190 .collect(),
191 ..Default::default()
192 }
193}
194
#[cfg(test)]
mod tests {
    use crate::native_types::Expression;
    use acir_field::{AcirField, FieldElement};

    /// Parses `src` into an expression, panicking if parsing fails.
    fn parse(src: &str) -> Expression<FieldElement> {
        Expression::from_str(src).unwrap()
    }

    /// Builds a constant-only expression from an unsigned integer.
    fn constant(value: u128) -> Expression<FieldElement> {
        Expression::from_field(FieldElement::from(value))
    }

    /// Addition merges the terms of both operands and is commutative.
    #[test]
    fn add_smoke_test() {
        let lhs = parse("2*w2 + 2");
        let rhs = parse("4*w4 + 1");
        let expected = parse("2*w2 + 4*w4 + 3");
        assert_eq!(&lhs + &rhs, expected);

        // Operand order must not matter.
        assert_eq!(&lhs + &rhs, &rhs + &lhs);
    }

    /// Multiplying two linear expressions expands to the expected product.
    #[test]
    fn mul_smoke_test() {
        let lhs = parse("2*w2 + 2");
        let rhs = parse("4*w4 + 1");
        let expected = parse("8*w2*w4 + 2*w2 + 8*w4 + 2");
        assert_eq!((&lhs * &rhs).unwrap(), expected);

        // Operand order must not matter.
        assert_eq!(&lhs * &rhs, &rhs * &lhs);
    }

    /// Multiplying by the zero expression leaves only zero coefficients.
    #[test]
    fn mul_by_zero_constant() {
        let operand = parse("3*w1 + 5*w2 + 7");
        let zero: Expression<FieldElement> = Expression::zero();

        let product = (&operand * &zero).unwrap();
        assert!(product.mul_terms.is_empty());
        assert!(product.q_c.is_zero());
        // Linear terms may still be present, but each coefficient must be zero.
        assert!(product.linear_combinations.iter().all(|(coeff, _)| coeff.is_zero()));

        // Operand order must not matter.
        assert_eq!(&operand * &zero, &zero * &operand);
    }

    /// Multiplying by the one expression returns the operand unchanged.
    #[test]
    fn mul_by_one_constant() {
        let operand = parse("3*w1 + 5*w2 + 7");
        let one: Expression<FieldElement> = Expression::one();

        let product = (&operand * &one).unwrap();
        assert_eq!(product, operand);

        // Operand order must not matter.
        assert_eq!(&operand * &one, &one * &operand);
    }

    /// Multiplying by a constant expression scales every coefficient.
    #[test]
    fn mul_by_scalar_constant() {
        let operand = parse("2*w1 + 3*w2 + 4");
        let scalar = constant(5);

        let product = (&operand * &scalar).unwrap();
        assert_eq!(product.to_string(), "10*w1 + 15*w2 + 20");

        // Operand order must not matter.
        assert_eq!(&operand * &scalar, &scalar * &operand);
    }

    /// The product of two constants is the constant product.
    #[test]
    fn mul_two_constants() {
        let lhs = constant(3);
        let rhs = constant(7);

        let product = (&lhs * &rhs).unwrap();
        assert_eq!(product, constant(21));

        // Operand order must not matter.
        assert_eq!(&lhs * &rhs, &rhs * &lhs);
    }

    /// Two purely linear expressions multiply into all pairwise products.
    #[test]
    fn mul_linear_expressions() {
        let lhs = parse("2*w1 + 3*w2");
        let rhs = parse("4*w3 + 5*w4");

        let product = (&lhs * &rhs).unwrap();
        assert_eq!(product.to_string(), "8*w1*w3 + 10*w1*w4 + 12*w2*w3 + 15*w2*w4");

        // Operand order must not matter.
        assert_eq!(&lhs * &rhs, &rhs * &lhs);
    }

    /// A witness present in both operands yields a squared term.
    #[test]
    fn mul_with_shared_witness() {
        let lhs = parse("2*w1 + 3*w2");
        let rhs = parse("4*w1 + 5*w3");

        let product = (&lhs * &rhs).unwrap();
        assert_eq!(product.to_string(), "8*w1*w1 + 12*w1*w2 + 10*w1*w3 + 15*w2*w3");

        // Operand order must not matter.
        assert_eq!(&lhs * &rhs, &rhs * &lhs);
    }

    /// A single witness times itself gives its square.
    #[test]
    fn mul_single_witness() {
        let lhs = parse("w1");
        let rhs = parse("w1");

        let product = (&lhs * &rhs).unwrap();
        assert_eq!(product.to_string(), "w1*w1");
    }

    /// Constant terms contribute both linear cross terms and a constant.
    #[test]
    fn mul_with_constant_term() {
        let lhs = parse("2*w1 + 3");
        let rhs = parse("4*w2 + 5");

        let product = (&lhs * &rhs).unwrap();
        assert_eq!(product.to_string(), "8*w1*w2 + 10*w1 + 12*w2 + 15");

        // Operand order must not matter.
        assert_eq!(&lhs * &rhs, &rhs * &lhs);
    }

    /// A degree-two operand times a non-constant operand yields `None`.
    #[test]
    fn mul_degree_two_fails() {
        let quadratic = parse("2*w1*w2 + 3*w1");
        let linear = parse("4*w3 + 5");

        let product = &quadratic * &linear;
        assert!(product.is_none(), "Multiplication should fail for degree > 2");

        // The failure must not depend on operand order.
        assert_eq!(&quadratic * &linear, &linear * &quadratic);
    }

    /// Two degree-two operands yield `None`.
    #[test]
    fn mul_both_degree_two_fails() {
        let lhs = parse("w1*w2");
        let rhs = parse("w3*w4");

        let product = &lhs * &rhs;
        assert!(product.is_none(), "Multiplication of two degree-2 expressions should fail");

        // The failure must not depend on operand order.
        assert_eq!(&lhs * &rhs, &rhs * &lhs);
    }

    /// A larger product with several witnesses and constants on both sides.
    #[test]
    fn mul_complex_linear_expressions() {
        let lhs = parse("2*w1 + 3*w2 + 4*w3 + 5");
        let rhs = parse("6*w4 + 7*w5 + 8");

        let product = (&lhs * &rhs).unwrap();
        assert_eq!(
            product.to_string(),
            "12*w1*w4 + 14*w1*w5 + 18*w2*w4 + 21*w2*w5 + 24*w3*w4 + 28*w3*w5 + 16*w1 + 24*w2 + 32*w3 + 30*w4 + 35*w5 + 40"
        );

        // Operand order must not matter.
        assert_eq!(&lhs * &rhs, &rhs * &lhs);
    }

    /// Witnesses inside a multiplication term are printed smaller-first.
    #[test]
    fn mul_witness_ordering() {
        let lhs = parse("w5");
        let rhs = parse("w2");

        let product = (&lhs * &rhs).unwrap();
        assert_eq!(product.to_string(), "w2*w5");

        // Operand order must not matter.
        assert_eq!(&lhs * &rhs, &rhs * &lhs);
    }

    /// Every multiplication term of a product stores its witnesses in order.
    #[test]
    fn mul_result_is_sorted() {
        let lhs = parse("w3 + w1");
        let rhs = parse("w4 + w2");

        let product = (&lhs * &rhs).unwrap();
        for (_, left, right) in &product.mul_terms {
            assert!(left <= right, "Witnesses in mul_terms should be ordered");
        }
    }

    /// Negating through a reference flips every sign and keeps the operand.
    #[test]
    fn neg_reference() {
        let original = parse("2*w1*w2 + 3*w1 + 5*w2 + 7");
        let negated = -&original;

        assert_eq!(negated.to_string(), "-2*w1*w2 - 3*w1 - 5*w2 - 7");

        // The borrowed operand must be unchanged.
        assert_eq!(original.to_string(), "2*w1*w2 + 3*w1 + 5*w2 + 7");
    }

    /// Negating an owned expression flips every sign.
    #[test]
    fn neg_owned() {
        let original = parse("2*w1*w2 + 3*w1 + 5*w2 + 7");
        let negated = -original;

        assert_eq!(negated.to_string(), "-2*w1*w2 - 3*w1 - 5*w2 - 7");
    }

    /// Negating zero gives zero.
    #[test]
    fn neg_zero() {
        let zero: Expression<FieldElement> = Expression::zero();
        let negated = -&zero;

        assert_eq!(negated, Expression::zero());
    }

    /// Negating a constant negates `q_c` and introduces no terms.
    #[test]
    fn neg_constant() {
        let original = constant(42);
        let negated = -original;

        assert_eq!(negated.q_c, FieldElement::from(-42i128));
        assert!(negated.mul_terms.is_empty());
        assert!(negated.linear_combinations.is_empty());
    }

    /// Negation of an expression without multiplication terms.
    #[test]
    fn neg_linear_only() {
        let original = parse("3*w1 + 5*w2 + 7");
        let negated = -original;

        assert_eq!(negated.to_string(), "-3*w1 - 5*w2 - 7");
    }

    /// Negation of an expression made only of multiplication terms.
    #[test]
    fn neg_mul_only() {
        let original = parse("2*w1*w2 + 4*w3*w4");
        let negated = -original;

        assert_eq!(negated.to_string(), "-2*w1*w2 - 4*w3*w4");
    }

    /// Negating twice returns the original expression.
    #[test]
    fn double_neg() {
        let original = parse("2*w1*w2 + 3*w1 + 5");
        let round_trip = -(-original.clone());

        assert_eq!(round_trip, original);
    }

    /// Negation keeps the number of multiplication and linear terms.
    #[test]
    fn neg_preserves_structure() {
        let original = parse("2*w1*w2 + 3*w3*w4 + 5*w1 + 7*w2 + 11");
        let negated = -&original;

        assert_eq!(negated.mul_terms.len(), original.mul_terms.len());
        assert_eq!(negated.linear_combinations.len(), original.linear_combinations.len());
    }
}