brillig_vm/
arithmetic.rs

1//! Implementations for [binary field operations][acir::brillig::Opcode::BinaryFieldOp] and
2//! [binary integer operations][acir::brillig::Opcode::BinaryIntOp].
3use std::ops::{BitAnd, BitOr, BitXor, Shl, Shr};
4
5use acir::AcirField;
6use acir::brillig::{BinaryFieldOp, BinaryIntOp, BitSize, IntegerBitSize};
7use num_bigint::BigUint;
8use num_traits::{CheckedDiv, ToPrimitive, WrappingAdd, WrappingMul, WrappingSub, Zero};
9
10use crate::memory::{MemoryTypeError, MemoryValue};
11
12#[derive(Debug, PartialEq, thiserror::Error)]
13pub(crate) enum BrilligArithmeticError {
14    #[error("Bit size for lhs {lhs_bit_size} does not match op bit size {op_bit_size}")]
15    MismatchedLhsBitSize { lhs_bit_size: u32, op_bit_size: u32 },
16    #[error("Bit size for rhs {rhs_bit_size} does not match op bit size {op_bit_size}")]
17    MismatchedRhsBitSize { rhs_bit_size: u32, op_bit_size: u32 },
18    #[error("Attempted to divide by zero")]
19    DivisionByZero,
20}
21
22/// Evaluate a binary operation on two FieldElement memory values.
23pub(crate) fn evaluate_binary_field_op<F: AcirField>(
24    op: &BinaryFieldOp,
25    lhs: MemoryValue<F>,
26    rhs: MemoryValue<F>,
27) -> Result<MemoryValue<F>, BrilligArithmeticError> {
28    let expect_field = |value: MemoryValue<F>, make_err: fn(u32, u32) -> BrilligArithmeticError| {
29        value.expect_field().map_err(|err| {
30            if let MemoryTypeError::MismatchedBitSize { value_bit_size, expected_bit_size } = err {
31                make_err(value_bit_size, expected_bit_size)
32            } else {
33                unreachable!("MemoryTypeError NotInteger is only produced by to_u128")
34            }
35        })
36    };
37    let a = expect_field(lhs, |vbs, ebs| BrilligArithmeticError::MismatchedLhsBitSize {
38        lhs_bit_size: vbs,
39        op_bit_size: ebs,
40    })?;
41    let b = expect_field(rhs, |vbs, ebs| BrilligArithmeticError::MismatchedRhsBitSize {
42        rhs_bit_size: vbs,
43        op_bit_size: ebs,
44    })?;
45
46    Ok(match op {
47        // Perform addition, subtraction, multiplication, and division based on the BinaryOp variant.
48        BinaryFieldOp::Add => MemoryValue::new_field(a + b),
49        BinaryFieldOp::Sub => MemoryValue::new_field(a - b),
50        BinaryFieldOp::Mul => MemoryValue::new_field(a * b),
51        BinaryFieldOp::Div => {
52            if b.is_zero() {
53                return Err(BrilligArithmeticError::DivisionByZero);
54            } else if b.is_one() {
55                MemoryValue::new_field(a)
56            } else if b == -F::one() {
57                MemoryValue::new_field(-a)
58            } else {
59                MemoryValue::new_field(a / b)
60            }
61        }
62        BinaryFieldOp::IntegerDiv => {
63            // IntegerDiv is only meant to represent unsigned integer division.
64            // The operands must be valid non-negative integers within the field's range.
65            // Because AcirField is modulo the prime field, it does not natively track
66            // "negative" numbers as any value is already reduced modulo the prime.
67            //
68            // Therefore, we do not check for negative inputs here. It is the responsibility
69            // of the code generator to ensure that operands for IntegerDiv are valid unsigned integers.
70            // The only runtime error we check for is division by zero.
71            if b.is_zero() {
72                return Err(BrilligArithmeticError::DivisionByZero);
73            } else {
74                let a_big = BigUint::from_bytes_be(&a.to_be_bytes());
75                let b_big = BigUint::from_bytes_be(&b.to_be_bytes());
76
77                let result = a_big / b_big;
78                MemoryValue::new_field(F::from_be_bytes_reduce(&result.to_bytes_be()))
79            }
80        }
81        BinaryFieldOp::Equals => (a == b).into(),
82        BinaryFieldOp::LessThan => (a < b).into(),
83        BinaryFieldOp::LessThanEquals => (a <= b).into(),
84    })
85}
86
87/// Evaluate a binary operation on two unsigned big integers with a given bit size.
88pub(crate) fn evaluate_binary_int_op<F: AcirField>(
89    op: &BinaryIntOp,
90    lhs: MemoryValue<F>,
91    rhs: MemoryValue<F>,
92    bit_size: IntegerBitSize,
93) -> Result<MemoryValue<F>, BrilligArithmeticError> {
94    match op {
95        BinaryIntOp::Add
96        | BinaryIntOp::Sub
97        | BinaryIntOp::Mul
98        | BinaryIntOp::Div
99        | BinaryIntOp::And
100        | BinaryIntOp::Or
101        | BinaryIntOp::Xor => match (lhs, rhs, bit_size) {
102            (MemoryValue::U1(lhs), MemoryValue::U1(rhs), IntegerBitSize::U1) => {
103                evaluate_binary_int_op_u1(op, lhs, rhs).map(MemoryValue::U1)
104            }
105            (MemoryValue::U8(lhs), MemoryValue::U8(rhs), IntegerBitSize::U8) => {
106                evaluate_binary_int_op_arith(op, lhs, rhs).map(MemoryValue::U8)
107            }
108            (MemoryValue::U16(lhs), MemoryValue::U16(rhs), IntegerBitSize::U16) => {
109                evaluate_binary_int_op_arith(op, lhs, rhs).map(MemoryValue::U16)
110            }
111            (MemoryValue::U32(lhs), MemoryValue::U32(rhs), IntegerBitSize::U32) => {
112                evaluate_binary_int_op_arith(op, lhs, rhs).map(MemoryValue::U32)
113            }
114            (MemoryValue::U64(lhs), MemoryValue::U64(rhs), IntegerBitSize::U64) => {
115                evaluate_binary_int_op_arith(op, lhs, rhs).map(MemoryValue::U64)
116            }
117            (MemoryValue::U128(lhs), MemoryValue::U128(rhs), IntegerBitSize::U128) => {
118                evaluate_binary_int_op_arith(op, lhs, rhs).map(MemoryValue::U128)
119            }
120            (lhs, _, _) if lhs.bit_size() != BitSize::Integer(bit_size) => {
121                Err(BrilligArithmeticError::MismatchedLhsBitSize {
122                    lhs_bit_size: lhs.bit_size().to_u32::<F>(),
123                    op_bit_size: bit_size.into(),
124                })
125            }
126            (_, rhs, _) if rhs.bit_size() != BitSize::Integer(bit_size) => {
127                Err(BrilligArithmeticError::MismatchedRhsBitSize {
128                    rhs_bit_size: rhs.bit_size().to_u32::<F>(),
129                    op_bit_size: bit_size.into(),
130                })
131            }
132            _ => unreachable!("Invalid arguments are covered by the two arms above."),
133        },
134
135        BinaryIntOp::Equals | BinaryIntOp::LessThan | BinaryIntOp::LessThanEquals => {
136            match (lhs, rhs, bit_size) {
137                (MemoryValue::U1(lhs), MemoryValue::U1(rhs), IntegerBitSize::U1) => {
138                    Ok(MemoryValue::U1(evaluate_binary_int_op_cmp(op, lhs, rhs)))
139                }
140                (MemoryValue::U8(lhs), MemoryValue::U8(rhs), IntegerBitSize::U8) => {
141                    Ok(MemoryValue::U1(evaluate_binary_int_op_cmp(op, lhs, rhs)))
142                }
143                (MemoryValue::U16(lhs), MemoryValue::U16(rhs), IntegerBitSize::U16) => {
144                    Ok(MemoryValue::U1(evaluate_binary_int_op_cmp(op, lhs, rhs)))
145                }
146                (MemoryValue::U32(lhs), MemoryValue::U32(rhs), IntegerBitSize::U32) => {
147                    Ok(MemoryValue::U1(evaluate_binary_int_op_cmp(op, lhs, rhs)))
148                }
149                (MemoryValue::U64(lhs), MemoryValue::U64(rhs), IntegerBitSize::U64) => {
150                    Ok(MemoryValue::U1(evaluate_binary_int_op_cmp(op, lhs, rhs)))
151                }
152                (MemoryValue::U128(lhs), MemoryValue::U128(rhs), IntegerBitSize::U128) => {
153                    Ok(MemoryValue::U1(evaluate_binary_int_op_cmp(op, lhs, rhs)))
154                }
155                (lhs, _, _) if lhs.bit_size() != BitSize::Integer(bit_size) => {
156                    Err(BrilligArithmeticError::MismatchedLhsBitSize {
157                        lhs_bit_size: lhs.bit_size().to_u32::<F>(),
158                        op_bit_size: bit_size.into(),
159                    })
160                }
161                (_, rhs, _) if rhs.bit_size() != BitSize::Integer(bit_size) => {
162                    Err(BrilligArithmeticError::MismatchedRhsBitSize {
163                        rhs_bit_size: rhs.bit_size().to_u32::<F>(),
164                        op_bit_size: bit_size.into(),
165                    })
166                }
167                _ => unreachable!("Invalid arguments are covered by the two arms above."),
168            }
169        }
170
171        BinaryIntOp::Shl | BinaryIntOp::Shr => match (lhs, rhs, bit_size) {
172            (MemoryValue::U1(lhs), MemoryValue::U1(rhs), IntegerBitSize::U1) => {
173                if rhs {
174                    Ok(MemoryValue::U1(false))
175                } else {
176                    Ok(MemoryValue::U1(lhs))
177                }
178            }
179            (MemoryValue::U8(lhs), MemoryValue::U8(rhs), IntegerBitSize::U8) => {
180                Ok(MemoryValue::U8(evaluate_binary_int_op_shifts(op, lhs, rhs)?))
181            }
182            (MemoryValue::U16(lhs), MemoryValue::U16(rhs), IntegerBitSize::U16) => {
183                Ok(MemoryValue::U16(evaluate_binary_int_op_shifts(op, lhs, rhs)?))
184            }
185            (MemoryValue::U32(lhs), MemoryValue::U32(rhs), IntegerBitSize::U32) => {
186                Ok(MemoryValue::U32(evaluate_binary_int_op_shifts(op, lhs, rhs)?))
187            }
188            (MemoryValue::U64(lhs), MemoryValue::U64(rhs), IntegerBitSize::U64) => {
189                Ok(MemoryValue::U64(evaluate_binary_int_op_shifts(op, lhs, rhs)?))
190            }
191            (MemoryValue::U128(lhs), MemoryValue::U128(rhs), IntegerBitSize::U128) => {
192                Ok(MemoryValue::U128(evaluate_binary_int_op_shifts(op, lhs, rhs)?))
193            }
194            _ => Err(BrilligArithmeticError::MismatchedLhsBitSize {
195                lhs_bit_size: lhs.bit_size().to_u32::<F>(),
196                op_bit_size: bit_size.into(),
197            }),
198        },
199    }
200}
201
202/// Evaluates binary operations on 1-bit unsigned integers (booleans).
203///
204/// # Returns
205/// - Ok(result) if successful.
206/// - Err([BrilligArithmeticError::DivisionByZero]) if division by zero occurs.
207///
208/// # Panics
209/// If an operation other than `Add`, `Sub`, `Mul`, `Div`, `And`, `Or`, `Xor`, `Equals`, `LessThan`,
210/// or `LessThanEquals` is supplied as an argument.
211fn evaluate_binary_int_op_u1(
212    op: &BinaryIntOp,
213    lhs: bool,
214    rhs: bool,
215) -> Result<bool, BrilligArithmeticError> {
216    let result = match op {
217        BinaryIntOp::Equals => lhs == rhs,
218        BinaryIntOp::LessThan => !lhs & rhs,
219        BinaryIntOp::LessThanEquals => lhs <= rhs,
220        BinaryIntOp::And | BinaryIntOp::Mul => lhs & rhs,
221        BinaryIntOp::Or => lhs | rhs,
222        BinaryIntOp::Xor | BinaryIntOp::Add | BinaryIntOp::Sub => lhs ^ rhs,
223        BinaryIntOp::Div => {
224            if !rhs {
225                return Err(BrilligArithmeticError::DivisionByZero);
226            } else {
227                lhs
228            }
229        }
230        _ => unreachable!("Operator not handled by this function: {op:?}"),
231    };
232    Ok(result)
233}
234
235/// Evaluates comparison operations (`Equals`, `LessThan`, `LessThanEquals`)
236/// between two values of an ordered type (e.g., fields are unordered).
237///
238/// # Panics
239/// If an unsupported operator is provided (i.e., not `Equals`, `LessThan`, or `LessThanEquals`).
240fn evaluate_binary_int_op_cmp<T: Ord + PartialEq>(op: &BinaryIntOp, lhs: T, rhs: T) -> bool {
241    match op {
242        BinaryIntOp::Equals => lhs == rhs,
243        BinaryIntOp::LessThan => lhs < rhs,
244        BinaryIntOp::LessThanEquals => lhs <= rhs,
245        _ => unreachable!("Operator not handled by this function: {op:?}"),
246    }
247}
248
249/// Evaluates shift operations (`Shl`, `Shr`) for unsigned integers.
250/// Ensures that shifting beyond the type width returns zero.
251///
252/// # Returns
253/// - Ok(result)
254///
255/// # Panics
256/// If an unsupported operator is provided (i.e., not `Shl` or `Shr`).
257fn evaluate_binary_int_op_shifts<T: ToPrimitive + Zero + Shl<Output = T> + Shr<Output = T>>(
258    op: &BinaryIntOp,
259    lhs: T,
260    rhs: T,
261) -> Result<T, BrilligArithmeticError> {
262    let bit_size = (size_of::<T>() * 8) as u128;
263    let rhs_val = rhs.to_u128().unwrap();
264    if rhs_val >= bit_size {
265        return Ok(T::zero());
266    }
267    match op {
268        BinaryIntOp::Shl => Ok(lhs << rhs),
269        BinaryIntOp::Shr => Ok(lhs >> rhs),
270        _ => unreachable!("Operator not handled by this function: {op:?}"),
271    }
272}
273
274/// Evaluates arithmetic or bitwise operations on unsigned integer types,
275/// using wrapping arithmetic for [add][BinaryIntOp::Add], [sub][BinaryIntOp::Sub], and [mul][BinaryIntOp::Mul].
276///
277/// # Returns
278/// - Ok(result) if successful.
279/// - Err([BrilligArithmeticError::DivisionByZero]) if division by zero occurs.
280///
281/// # Panics
282/// If there an operation other than Add, Sub, Mul, Div, And, Or, Xor is supplied as an argument.
283fn evaluate_binary_int_op_arith<
284    T: WrappingAdd
285        + WrappingSub
286        + WrappingMul
287        + CheckedDiv
288        + BitAnd<Output = T>
289        + BitOr<Output = T>
290        + BitXor<Output = T>,
291>(
292    op: &BinaryIntOp,
293    lhs: T,
294    rhs: T,
295) -> Result<T, BrilligArithmeticError> {
296    let result = match op {
297        BinaryIntOp::Add => lhs.wrapping_add(&rhs),
298        BinaryIntOp::Sub => lhs.wrapping_sub(&rhs),
299        BinaryIntOp::Mul => lhs.wrapping_mul(&rhs),
300        BinaryIntOp::Div => lhs.checked_div(&rhs).ok_or(BrilligArithmeticError::DivisionByZero)?,
301        BinaryIntOp::And => lhs & rhs,
302        BinaryIntOp::Or => lhs | rhs,
303        BinaryIntOp::Xor => lhs ^ rhs,
304        _ => unreachable!("Operator not handled by this function: {op:?}"),
305    };
306    Ok(result)
307}
308
#[cfg(test)]
mod int_ops {
    use super::*;
    use acir::{AcirField, FieldElement};

    /// One integer-operation case: `a op b` is expected to equal `result`. All three values
    /// are unsigned encodings of integers of the op's bit size.
    struct TestParams {
        a: u128,
        b: u128,
        result: u128,
    }

    /// Evaluates `a op b` as integers of `bit_size`, panicking on any error, and returns the
    /// result converted back to a `u128`.
    fn evaluate_u128(op: &BinaryIntOp, a: u128, b: u128, bit_size: IntegerBitSize) -> u128 {
        let result_value: MemoryValue<FieldElement> = evaluate_binary_int_op(
            op,
            MemoryValue::new_integer(a, bit_size),
            MemoryValue::new_integer(b, bit_size),
            bit_size,
        )
        .unwrap();
        // Convert back to u128
        result_value.to_field().to_u128()
    }

    /// Returns the two's-complement encoding of `-a` in `bit_size` bits, i.e. `2^bit_size - a`.
    fn to_negative(a: u128, bit_size: IntegerBitSize) -> u128 {
        assert!(a > 0);
        if bit_size == IntegerBitSize::U128 {
            // 2^128 does not fit in a u128, so let wrapping subtraction produce 2^128 - a.
            0_u128.wrapping_sub(a)
        } else {
            let two_pow = 2_u128.pow(bit_size.into());
            two_pow - a
        }
    }

    /// Asserts that every case in `test_params` evaluates to its expected result.
    fn evaluate_int_ops(test_params: Vec<TestParams>, op: BinaryIntOp, bit_size: IntegerBitSize) {
        for test in test_params {
            assert_eq!(evaluate_u128(&op, test.a, test.b, bit_size), test.result);
        }
    }

    #[test]
    fn add_test() {
        let bit_size = IntegerBitSize::U8;

        let test_ops = vec![
            TestParams { a: 50, b: 100, result: 150 },
            TestParams { a: 250, b: 10, result: 4 },
            TestParams { a: 5, b: to_negative(3, bit_size), result: 2 },
            TestParams { a: to_negative(3, bit_size), b: 1, result: to_negative(2, bit_size) },
            TestParams { a: 5, b: to_negative(6, bit_size), result: to_negative(1, bit_size) },
        ];
        evaluate_int_ops(test_ops, BinaryIntOp::Add, bit_size);

        let bit_size = IntegerBitSize::U128;
        let test_ops = vec![
            TestParams { a: 5, b: to_negative(3, bit_size), result: 2 },
            TestParams { a: to_negative(3, bit_size), b: 1, result: to_negative(2, bit_size) },
        ];

        evaluate_int_ops(test_ops, BinaryIntOp::Add, bit_size);

        // Mismatched bit sizes should error
        assert_eq!(
            evaluate_binary_int_op(
                &BinaryIntOp::Add,
                MemoryValue::<FieldElement>::U8(1),
                MemoryValue::<FieldElement>::U16(2),
                IntegerBitSize::U8
            ),
            Err(BrilligArithmeticError::MismatchedRhsBitSize { rhs_bit_size: 16, op_bit_size: 8 })
        );
        assert_eq!(
            evaluate_binary_int_op(
                &BinaryIntOp::Add,
                MemoryValue::<FieldElement>::U16(2),
                MemoryValue::<FieldElement>::U8(1),
                IntegerBitSize::U8
            ),
            Err(BrilligArithmeticError::MismatchedLhsBitSize { lhs_bit_size: 16, op_bit_size: 8 })
        );
    }

    #[test]
    fn sub_test() {
        let bit_size = IntegerBitSize::U8;

        let test_ops = vec![
            TestParams { a: 50, b: 30, result: 20 },
            TestParams { a: 5, b: 10, result: to_negative(5, bit_size) },
            TestParams { a: 5, b: to_negative(3, bit_size), result: 8 },
            TestParams { a: to_negative(3, bit_size), b: 2, result: to_negative(5, bit_size) },
            TestParams { a: 254, b: to_negative(3, bit_size), result: 1 },
        ];
        evaluate_int_ops(test_ops, BinaryIntOp::Sub, bit_size);

        let bit_size = IntegerBitSize::U128;

        let test_ops = vec![
            TestParams { a: 5, b: 10, result: to_negative(5, bit_size) },
            TestParams { a: to_negative(3, bit_size), b: 2, result: to_negative(5, bit_size) },
        ];
        evaluate_int_ops(test_ops, BinaryIntOp::Sub, bit_size);

        // Mismatched bit sizes should error
        assert_eq!(
            evaluate_binary_int_op(
                &BinaryIntOp::Sub,
                MemoryValue::<FieldElement>::U8(1),
                MemoryValue::<FieldElement>::U16(1),
                IntegerBitSize::U8
            ),
            Err(BrilligArithmeticError::MismatchedRhsBitSize { rhs_bit_size: 16, op_bit_size: 8 })
        );
        assert_eq!(
            evaluate_binary_int_op(
                &BinaryIntOp::Sub,
                MemoryValue::<FieldElement>::U16(1),
                MemoryValue::<FieldElement>::U8(1),
                IntegerBitSize::U8
            ),
            Err(BrilligArithmeticError::MismatchedLhsBitSize { lhs_bit_size: 16, op_bit_size: 8 })
        );
    }

    #[test]
    fn mul_test() {
        let bit_size = IntegerBitSize::U8;

        let test_ops = vec![
            TestParams { a: 5, b: 3, result: 15 },
            TestParams { a: 5, b: 100, result: 244 },
            TestParams { a: to_negative(1, bit_size), b: to_negative(5, bit_size), result: 5 },
            TestParams { a: to_negative(1, bit_size), b: 5, result: to_negative(5, bit_size) },
            TestParams { a: to_negative(2, bit_size), b: 7, result: to_negative(14, bit_size) },
        ];

        evaluate_int_ops(test_ops, BinaryIntOp::Mul, bit_size);

        let bit_size = IntegerBitSize::U64;
        let a = 2_u128.pow(bit_size.into()) - 1;
        let b = 3;

        // (2**n - 1) * 3 = 3 * 2**n - 3, which wraps modulo 2**n to 2**n - 3 = a - 2
        assert_eq!(evaluate_u128(&BinaryIntOp::Mul, a, b, bit_size), a - 2);

        let bit_size = IntegerBitSize::U128;

        let test_ops = vec![
            TestParams { a: to_negative(1, bit_size), b: to_negative(5, bit_size), result: 5 },
            TestParams { a: to_negative(1, bit_size), b: 5, result: to_negative(5, bit_size) },
            TestParams { a: to_negative(2, bit_size), b: 7, result: to_negative(14, bit_size) },
        ];

        evaluate_int_ops(test_ops, BinaryIntOp::Mul, bit_size);

        // Mismatched bit sizes should error
        assert_eq!(
            evaluate_binary_int_op(
                &BinaryIntOp::Mul,
                MemoryValue::<FieldElement>::U8(1),
                MemoryValue::<FieldElement>::U16(1),
                IntegerBitSize::U8
            ),
            Err(BrilligArithmeticError::MismatchedRhsBitSize { rhs_bit_size: 16, op_bit_size: 8 })
        );
        assert_eq!(
            evaluate_binary_int_op(
                &BinaryIntOp::Mul,
                MemoryValue::<FieldElement>::U16(1),
                MemoryValue::<FieldElement>::U8(1),
                IntegerBitSize::U8
            ),
            Err(BrilligArithmeticError::MismatchedLhsBitSize { lhs_bit_size: 16, op_bit_size: 8 })
        );
    }

    #[test]
    fn div_test() {
        let bit_size = IntegerBitSize::U8;

        let test_ops =
            vec![TestParams { a: 5, b: 3, result: 1 }, TestParams { a: 5, b: 10, result: 0 }];

        evaluate_int_ops(test_ops, BinaryIntOp::Div, bit_size);

        // Mismatched bit sizes should error
        assert_eq!(
            evaluate_binary_int_op(
                &BinaryIntOp::Div,
                MemoryValue::<FieldElement>::U8(1),
                MemoryValue::<FieldElement>::U16(1),
                IntegerBitSize::U8
            ),
            Err(BrilligArithmeticError::MismatchedRhsBitSize { rhs_bit_size: 16, op_bit_size: 8 })
        );
        assert_eq!(
            evaluate_binary_int_op(
                &BinaryIntOp::Div,
                MemoryValue::<FieldElement>::U16(1),
                MemoryValue::<FieldElement>::U8(1),
                IntegerBitSize::U8
            ),
            Err(BrilligArithmeticError::MismatchedLhsBitSize { lhs_bit_size: 16, op_bit_size: 8 })
        );

        // Division by zero should error
        assert_eq!(
            evaluate_binary_int_op(
                &BinaryIntOp::Div,
                MemoryValue::<FieldElement>::U8(1),
                MemoryValue::<FieldElement>::U8(0),
                IntegerBitSize::U8
            ),
            Err(BrilligArithmeticError::DivisionByZero)
        );
    }

    #[test]
    fn bitwise_ops_test() {
        let bit_size = IntegerBitSize::U8;

        let test_ops = vec![
            TestParams { a: 0b1100, b: 0b1010, result: 0b1000 },
            TestParams { a: 0xFF, b: 0x0F, result: 0x0F },
        ];
        evaluate_int_ops(test_ops, BinaryIntOp::And, bit_size);

        let test_ops = vec![
            TestParams { a: 0b1100, b: 0b1010, result: 0b1110 },
            TestParams { a: 0xF0, b: 0x0F, result: 0xFF },
        ];
        evaluate_int_ops(test_ops, BinaryIntOp::Or, bit_size);

        let test_ops = vec![
            TestParams { a: 0b1100, b: 0b1010, result: 0b0110 },
            TestParams { a: 0xFF, b: 0xFF, result: 0 },
        ];
        evaluate_int_ops(test_ops, BinaryIntOp::Xor, bit_size);

        // Mismatched bit sizes should error
        assert_eq!(
            evaluate_binary_int_op(
                &BinaryIntOp::Xor,
                MemoryValue::<FieldElement>::U8(1),
                MemoryValue::<FieldElement>::U16(1),
                IntegerBitSize::U8
            ),
            Err(BrilligArithmeticError::MismatchedRhsBitSize { rhs_bit_size: 16, op_bit_size: 8 })
        );
    }

    #[test]
    fn shl_test() {
        let bit_size = IntegerBitSize::U8;

        let test_ops =
            vec![TestParams { a: 1, b: 7, result: 128 }, TestParams { a: 5, b: 7, result: 128 }];

        evaluate_int_ops(test_ops, BinaryIntOp::Shl, bit_size);
        // Shifting by at least the bit width returns zero
        assert_eq!(
            evaluate_binary_int_op(
                &BinaryIntOp::Shl,
                MemoryValue::<FieldElement>::U8(1u8),
                MemoryValue::<FieldElement>::U8(8u8),
                IntegerBitSize::U8
            ),
            Ok(MemoryValue::<FieldElement>::U8(0u8))
        );

        // A left operand of the wrong type is reported against `lhs`
        assert_eq!(
            evaluate_binary_int_op(
                &BinaryIntOp::Shl,
                MemoryValue::<FieldElement>::U16(1),
                MemoryValue::<FieldElement>::U8(1),
                IntegerBitSize::U8
            ),
            Err(BrilligArithmeticError::MismatchedLhsBitSize { lhs_bit_size: 16, op_bit_size: 8 })
        );
    }

    #[test]
    fn shr_test() {
        let bit_size = IntegerBitSize::U8;

        let test_ops =
            vec![TestParams { a: 1, b: 0, result: 1 }, TestParams { a: 5, b: 1, result: 2 }];

        evaluate_int_ops(test_ops, BinaryIntOp::Shr, bit_size);
        // Shifting by at least the bit width returns zero
        assert_eq!(
            evaluate_binary_int_op(
                &BinaryIntOp::Shr,
                MemoryValue::<FieldElement>::U8(1u8),
                MemoryValue::<FieldElement>::U8(8u8),
                IntegerBitSize::U8
            ),
            Ok(MemoryValue::<FieldElement>::U8(0u8))
        );
    }

    /// Reference shift implementation used by the fuzz tests.
    ///
    /// Masks both operands to `bit_width` bits and returns 0 when the shift amount is at
    /// least `bit_width`. Otherwise it shifts in u128 arithmetic and truncates the result to
    /// `bit_width` bits.
    fn reference_shift(op: &BinaryIntOp, lhs: u128, rhs: u128, bit_width: u32) -> u128 {
        let mask = if bit_width == 128 { u128::MAX } else { (1u128 << bit_width) - 1 };
        let lhs = lhs & mask;
        let rhs = rhs & mask;
        if rhs >= u128::from(bit_width) {
            return 0;
        }
        let result = match op {
            BinaryIntOp::Shl => lhs << rhs as u32,
            BinaryIntOp::Shr => lhs >> rhs as u32,
            _ => unreachable!(),
        };
        result & mask
    }

    proptest::proptest! {
        #[test]
        fn shift_u8_fuzz(lhs in 0u128..=u128::from(u8::MAX), rhs in 0u128..=u128::from(u8::MAX)) {
            let bit_size = IntegerBitSize::U8;
            for op in [BinaryIntOp::Shl, BinaryIntOp::Shr] {
                let actual = evaluate_u128(&op, lhs, rhs, bit_size);
                let expected = reference_shift(&op, lhs, rhs, 8);
                proptest::prop_assert_eq!(actual, expected);
            }
        }

        #[test]
        fn shift_u16_fuzz(lhs in 0u128..=u128::from(u16::MAX), rhs in 0u128..=u128::from(u16::MAX)) {
            let bit_size = IntegerBitSize::U16;
            for op in [BinaryIntOp::Shl, BinaryIntOp::Shr] {
                let actual = evaluate_u128(&op, lhs, rhs, bit_size);
                let expected = reference_shift(&op, lhs, rhs, 16);
                proptest::prop_assert_eq!(actual, expected);
            }
        }

        #[test]
        fn shift_u32_fuzz(lhs in 0u128..=u128::from(u32::MAX), rhs in 0u128..=u128::from(u32::MAX)) {
            let bit_size = IntegerBitSize::U32;
            for op in [BinaryIntOp::Shl, BinaryIntOp::Shr] {
                let actual = evaluate_u128(&op, lhs, rhs, bit_size);
                let expected = reference_shift(&op, lhs, rhs, 32);
                proptest::prop_assert_eq!(actual, expected);
            }
        }

        #[test]
        fn shift_u64_fuzz(lhs in 0u128..=u128::from(u64::MAX), rhs in 0u128..=u128::from(u64::MAX)) {
            let bit_size = IntegerBitSize::U64;
            for op in [BinaryIntOp::Shl, BinaryIntOp::Shr] {
                let actual = evaluate_u128(&op, lhs, rhs, bit_size);
                let expected = reference_shift(&op, lhs, rhs, 64);
                proptest::prop_assert_eq!(actual, expected);
            }
        }

        #[test]
        fn shift_u128_fuzz(lhs: u128, rhs: u128) {
            let bit_size = IntegerBitSize::U128;
            for op in [BinaryIntOp::Shl, BinaryIntOp::Shr] {
                let actual = evaluate_u128(&op, lhs, rhs, bit_size);
                let expected = reference_shift(&op, lhs, rhs, 128);
                proptest::prop_assert_eq!(actual, expected);
            }
        }
    }

    #[test]
    fn comparison_ops_test() {
        let bit_size = IntegerBitSize::U8;

        // Equals
        let test_ops = vec![
            TestParams { a: 5, b: 5, result: 1 },
            TestParams { a: 10, b: 5, result: 0 },
            TestParams { a: 0, b: 0, result: 1 },
        ];
        evaluate_int_ops(test_ops, BinaryIntOp::Equals, bit_size);

        // LessThan
        let test_ops = vec![
            TestParams { a: 4, b: 5, result: 1 },
            TestParams { a: 5, b: 4, result: 0 },
            TestParams { a: 5, b: 5, result: 0 },
        ];
        evaluate_int_ops(test_ops, BinaryIntOp::LessThan, bit_size);

        // LessThanEquals
        let test_ops = vec![
            TestParams { a: 4, b: 5, result: 1 },
            TestParams { a: 5, b: 4, result: 0 },
            TestParams { a: 5, b: 5, result: 1 },
        ];
        evaluate_int_ops(test_ops, BinaryIntOp::LessThanEquals, bit_size);

        // Mismatched bit sizes should error
        assert_eq!(
            evaluate_binary_int_op(
                &BinaryIntOp::Equals,
                MemoryValue::<FieldElement>::U8(1),
                MemoryValue::<FieldElement>::U16(1),
                IntegerBitSize::U8
            ),
            Err(BrilligArithmeticError::MismatchedRhsBitSize { rhs_bit_size: 16, op_bit_size: 8 })
        );
        assert_eq!(
            evaluate_binary_int_op(
                &BinaryIntOp::LessThan,
                MemoryValue::<FieldElement>::U16(1),
                MemoryValue::<FieldElement>::U8(1),
                IntegerBitSize::U8
            ),
            Err(BrilligArithmeticError::MismatchedLhsBitSize { lhs_bit_size: 16, op_bit_size: 8 })
        );
        assert_eq!(
            evaluate_binary_int_op(
                &BinaryIntOp::LessThanEquals,
                MemoryValue::<FieldElement>::U16(1),
                MemoryValue::<FieldElement>::U8(1),
                IntegerBitSize::U8
            ),
            Err(BrilligArithmeticError::MismatchedLhsBitSize { lhs_bit_size: 16, op_bit_size: 8 })
        );
    }
}
692
693#[cfg(test)]
694mod field_ops {
695    use super::*;
696    use acir::{AcirField, FieldElement};
697
698    struct TestParams {
699        a: FieldElement,
700        b: FieldElement,
701        result: FieldElement,
702    }
703
704    fn evaluate_field_u128(op: &BinaryFieldOp, a: FieldElement, b: FieldElement) -> FieldElement {
705        let result_value: MemoryValue<FieldElement> =
706            evaluate_binary_field_op(op, MemoryValue::new_field(a), MemoryValue::new_field(b))
707                .unwrap();
708        // Convert back to FieldElement
709        result_value.to_field()
710    }
711
712    fn evaluate_field_ops(test_params: Vec<TestParams>, op: BinaryFieldOp) {
713        for test in test_params {
714            assert_eq!(evaluate_field_u128(&op, test.a, test.b), test.result);
715        }
716    }
717
718    #[test]
719    fn add_test() {
720        let test_ops = vec![
721            TestParams { a: 1u32.into(), b: 2u32.into(), result: 3u32.into() },
722            TestParams { a: 5u32.into(), b: 10u32.into(), result: 15u32.into() },
723            TestParams { a: 250u32.into(), b: 10u32.into(), result: 260u32.into() },
724        ];
725        evaluate_field_ops(test_ops, BinaryFieldOp::Add);
726
727        // Mismatched bit sizes
728        assert_eq!(
729            evaluate_binary_field_op(
730                &BinaryFieldOp::Add,
731                MemoryValue::new_field(FieldElement::from(1u32)),
732                MemoryValue::<FieldElement>::U16(2),
733            ),
734            Err(BrilligArithmeticError::MismatchedRhsBitSize {
735                rhs_bit_size: 16,
736                op_bit_size: 254
737            })
738        );
739
740        assert_eq!(
741            evaluate_binary_field_op(
742                &BinaryFieldOp::Add,
743                MemoryValue::<FieldElement>::U16(1),
744                MemoryValue::new_field(FieldElement::from(2u32)),
745            ),
746            Err(BrilligArithmeticError::MismatchedLhsBitSize {
747                lhs_bit_size: 16,
748                op_bit_size: 254
749            })
750        );
751    }
752
753    #[test]
754    fn sub_test() {
755        let test_ops = vec![
756            TestParams { a: 5u32.into(), b: 3u32.into(), result: 2u32.into() },
757            TestParams { a: 2u32.into(), b: 10u32.into(), result: FieldElement::from(-8_i128) },
758        ];
759        evaluate_field_ops(test_ops, BinaryFieldOp::Sub);
760
761        // Mismatched bit sizes
762        assert_eq!(
763            evaluate_binary_field_op(
764                &BinaryFieldOp::Sub,
765                MemoryValue::new_field(FieldElement::from(1u32)),
766                MemoryValue::<FieldElement>::U16(1),
767            ),
768            Err(BrilligArithmeticError::MismatchedRhsBitSize {
769                rhs_bit_size: 16,
770                op_bit_size: 254
771            })
772        );
773
774        assert_eq!(
775            evaluate_binary_field_op(
776                &BinaryFieldOp::Sub,
777                MemoryValue::<FieldElement>::U16(1),
778                MemoryValue::new_field(FieldElement::from(1u32)),
779            ),
780            Err(BrilligArithmeticError::MismatchedLhsBitSize {
781                lhs_bit_size: 16,
782                op_bit_size: 254
783            })
784        );
785    }
786
787    #[test]
788    fn mul_test() {
789        let test_ops = vec![
790            TestParams { a: 2u32.into(), b: 3u32.into(), result: 6u32.into() },
791            TestParams { a: 10u32.into(), b: 25u32.into(), result: 250u32.into() },
792        ];
793        evaluate_field_ops(test_ops, BinaryFieldOp::Mul);
794
795        // Mismatched bit sizes
796        assert_eq!(
797            evaluate_binary_field_op(
798                &BinaryFieldOp::Mul,
799                MemoryValue::new_field(FieldElement::from(1u32)),
800                MemoryValue::<FieldElement>::U16(1),
801            ),
802            Err(BrilligArithmeticError::MismatchedRhsBitSize {
803                rhs_bit_size: 16,
804                op_bit_size: 254
805            })
806        );
807
808        assert_eq!(
809            evaluate_binary_field_op(
810                &BinaryFieldOp::Mul,
811                MemoryValue::<FieldElement>::U16(1),
812                MemoryValue::new_field(FieldElement::from(1u32)),
813            ),
814            Err(BrilligArithmeticError::MismatchedLhsBitSize {
815                lhs_bit_size: 16,
816                op_bit_size: 254
817            })
818        );
819    }
820
821    #[test]
822    fn div_test() {
823        let test_ops = vec![
824            TestParams { a: 10u32.into(), b: 2u32.into(), result: 5u32.into() },
825            TestParams { a: 9u32.into(), b: 3u32.into(), result: 3u32.into() },
826            TestParams {
827                a: 10u32.into(),
828                b: FieldElement::from(-1_i128),
829                result: FieldElement::from(-10_i128),
830            },
831            TestParams { a: 10u32.into(), b: 1u32.into(), result: 10u32.into() },
832            // Field division is a * 1/b. The inverse of 20 is 7660885005143746327786242010840046280991927540145612020294371465301532973466
833            TestParams {
834                a: 10u32.into(),
835                b: 20u32.into(),
836                result: FieldElement::try_from_str(
837                    "10944121435919637611123202872628637544274182200208017171849102093287904247809",
838                )
839                .unwrap(),
840            },
841            // The inverse of 7 is 3126891838834182174606629392179610726935480628630862049099743455225115499374.
842            TestParams {
843                a: 100u32.into(),
844                b: 7u32.into(),
845                result: FieldElement::try_from_str(
846                    "6253783677668364349213258784359221453870961257261724098199486910450230998762",
847                )
848                .unwrap(),
849            },
850        ];
851        evaluate_field_ops(test_ops, BinaryFieldOp::Div);
852
853        // Division by zero
854        assert_eq!(
855            evaluate_binary_field_op(
856                &BinaryFieldOp::Div,
857                MemoryValue::new_field(FieldElement::from(1u128)),
858                MemoryValue::new_field(FieldElement::zero()),
859            ),
860            Err(BrilligArithmeticError::DivisionByZero)
861        );
862
863        // Mismatched bit sizes
864        assert_eq!(
865            evaluate_binary_field_op(
866                &BinaryFieldOp::Div,
867                MemoryValue::new_field(FieldElement::from(1u32)),
868                MemoryValue::<FieldElement>::U16(1),
869            ),
870            Err(BrilligArithmeticError::MismatchedRhsBitSize {
871                rhs_bit_size: 16,
872                op_bit_size: 254
873            })
874        );
875
876        assert_eq!(
877            evaluate_binary_field_op(
878                &BinaryFieldOp::Div,
879                MemoryValue::<FieldElement>::U16(1),
880                MemoryValue::new_field(FieldElement::from(1u32)),
881            ),
882            Err(BrilligArithmeticError::MismatchedLhsBitSize {
883                lhs_bit_size: 16,
884                op_bit_size: 254
885            })
886        );
887    }
888
889    #[test]
890    fn integer_div_test() {
891        let test_ops = vec![
892            TestParams { a: 10u32.into(), b: 2u32.into(), result: 5u32.into() },
893            TestParams { a: 9u32.into(), b: 3u32.into(), result: 3u32.into() },
894            // Negative numbers are treated as large unsigned numbers, thus we expect a result of 0 here
895            TestParams { a: 10u32.into(), b: FieldElement::from(-1_i128), result: 0u32.into() },
896            TestParams { a: 10u32.into(), b: 1u32.into(), result: 10u32.into() },
897            TestParams { a: 10u32.into(), b: 20u32.into(), result: 0u32.into() },
898            // 100 / 7 == 14 with a remainder of 2. The remainder is discarded.
899            TestParams { a: 100u32.into(), b: 7u32.into(), result: 14u32.into() },
900        ];
901        evaluate_field_ops(test_ops, BinaryFieldOp::IntegerDiv);
902
903        // Division by zero should error
904        assert_eq!(
905            evaluate_binary_field_op(
906                &BinaryFieldOp::IntegerDiv,
907                MemoryValue::new_field(FieldElement::from(1u128)),
908                MemoryValue::new_field(FieldElement::zero()),
909            ),
910            Err(BrilligArithmeticError::DivisionByZero)
911        );
912
913        // Mismatched bit sizes should error
914        assert_eq!(
915            evaluate_binary_field_op(
916                &BinaryFieldOp::IntegerDiv,
917                MemoryValue::new_field(FieldElement::from(1u32)),
918                MemoryValue::<FieldElement>::U16(1),
919            ),
920            Err(BrilligArithmeticError::MismatchedRhsBitSize {
921                rhs_bit_size: 16,
922                op_bit_size: 254
923            })
924        );
925
926        assert_eq!(
927            evaluate_binary_field_op(
928                &BinaryFieldOp::IntegerDiv,
929                MemoryValue::<FieldElement>::U16(1),
930                MemoryValue::new_field(FieldElement::from(1u32)),
931            ),
932            Err(BrilligArithmeticError::MismatchedLhsBitSize {
933                lhs_bit_size: 16,
934                op_bit_size: 254
935            })
936        );
937    }
938
939    #[test]
940    fn comparison_ops_test() {
941        // Equals
942        let test_ops = vec![
943            TestParams { a: 5u32.into(), b: 5u32.into(), result: 1u32.into() },
944            TestParams { a: 10u32.into(), b: 5u32.into(), result: 0u32.into() },
945            TestParams { a: 0u32.into(), b: 0u32.into(), result: 1u32.into() },
946        ];
947        evaluate_field_ops(test_ops, BinaryFieldOp::Equals);
948
949        // LessThan
950        let test_ops = vec![
951            TestParams { a: 4u32.into(), b: 5u32.into(), result: 1u32.into() },
952            TestParams { a: 5u32.into(), b: 4u32.into(), result: 0u32.into() },
953            TestParams { a: 5u32.into(), b: 5u32.into(), result: 0u32.into() },
954        ];
955        evaluate_field_ops(test_ops, BinaryFieldOp::LessThan);
956
957        // LessThanEquals
958        let test_ops = vec![
959            TestParams { a: 4u32.into(), b: 5u32.into(), result: 1u32.into() },
960            TestParams { a: 5u32.into(), b: 4u32.into(), result: 0u32.into() },
961            TestParams { a: 5u32.into(), b: 5u32.into(), result: 1u32.into() },
962        ];
963        evaluate_field_ops(test_ops, BinaryFieldOp::LessThanEquals);
964
965        // Mismatched bit sizes should error
966        assert_eq!(
967            evaluate_binary_field_op(
968                &BinaryFieldOp::Equals,
969                MemoryValue::new_field(1u32.into()),
970                MemoryValue::<FieldElement>::U16(1),
971            ),
972            Err(BrilligArithmeticError::MismatchedRhsBitSize {
973                rhs_bit_size: 16,
974                op_bit_size: 254
975            })
976        );
977
978        assert_eq!(
979            evaluate_binary_field_op(
980                &BinaryFieldOp::LessThan,
981                MemoryValue::<FieldElement>::U16(1),
982                MemoryValue::new_field(1u32.into()),
983            ),
984            Err(BrilligArithmeticError::MismatchedLhsBitSize {
985                lhs_bit_size: 16,
986                op_bit_size: 254
987            })
988        );
989
990        assert_eq!(
991            evaluate_binary_field_op(
992                &BinaryFieldOp::LessThanEquals,
993                MemoryValue::<FieldElement>::U16(1),
994                MemoryValue::new_field(1u32.into()),
995            ),
996            Err(BrilligArithmeticError::MismatchedLhsBitSize {
997                lhs_bit_size: 16,
998                op_bit_size: 254
999            })
1000        );
1001    }
1002}