brillig_vm/
arithmetic.rs

1//! Implementations for [binary field operations][acir::brillig::Opcode::BinaryFieldOp] and
2//! [binary integer operations][acir::brillig::Opcode::BinaryIntOp].
3use std::ops::{BitAnd, BitOr, BitXor, Shl, Shr};
4
5use acir::AcirField;
6use acir::brillig::{BinaryFieldOp, BinaryIntOp, BitSize, IntegerBitSize};
7use num_bigint::BigUint;
8use num_traits::{CheckedDiv, ToPrimitive, WrappingAdd, WrappingMul, WrappingSub, Zero};
9
10use crate::memory::{MemoryTypeError, MemoryValue};
11
12#[derive(Debug, PartialEq, thiserror::Error)]
13pub(crate) enum BrilligArithmeticError {
14    #[error("Bit size for lhs {lhs_bit_size} does not match op bit size {op_bit_size}")]
15    MismatchedLhsBitSize { lhs_bit_size: u32, op_bit_size: u32 },
16    #[error("Bit size for rhs {rhs_bit_size} does not match op bit size {op_bit_size}")]
17    MismatchedRhsBitSize { rhs_bit_size: u32, op_bit_size: u32 },
18    #[error("Attempted to divide by zero")]
19    DivisionByZero,
20}
21
22/// Evaluate a binary operation on two FieldElement memory values.
23pub(crate) fn evaluate_binary_field_op<F: AcirField>(
24    op: &BinaryFieldOp,
25    lhs: MemoryValue<F>,
26    rhs: MemoryValue<F>,
27) -> Result<MemoryValue<F>, BrilligArithmeticError> {
28    let expect_field = |value: MemoryValue<F>, make_err: fn(u32, u32) -> BrilligArithmeticError| {
29        value.expect_field().map_err(|err| {
30            if let MemoryTypeError::MismatchedBitSize { value_bit_size, expected_bit_size } = err {
31                make_err(value_bit_size, expected_bit_size)
32            } else {
33                unreachable!("MemoryTypeError NotInteger is only produced by to_u128")
34            }
35        })
36    };
37    let a = expect_field(lhs, |vbs, ebs| BrilligArithmeticError::MismatchedLhsBitSize {
38        lhs_bit_size: vbs,
39        op_bit_size: ebs,
40    })?;
41    let b = expect_field(rhs, |vbs, ebs| BrilligArithmeticError::MismatchedRhsBitSize {
42        rhs_bit_size: vbs,
43        op_bit_size: ebs,
44    })?;
45
46    Ok(match op {
47        // Perform addition, subtraction, multiplication, and division based on the BinaryOp variant.
48        BinaryFieldOp::Add => MemoryValue::new_field(a + b),
49        BinaryFieldOp::Sub => MemoryValue::new_field(a - b),
50        BinaryFieldOp::Mul => MemoryValue::new_field(a * b),
51        BinaryFieldOp::Div => {
52            if b.is_zero() {
53                return Err(BrilligArithmeticError::DivisionByZero);
54            } else if b.is_one() {
55                MemoryValue::new_field(a)
56            } else if b == -F::one() {
57                MemoryValue::new_field(-a)
58            } else {
59                MemoryValue::new_field(a / b)
60            }
61        }
62        BinaryFieldOp::IntegerDiv => {
63            // IntegerDiv is only meant to represent unsigned integer division.
64            // The operands must be valid non-negative integers within the field's range.
65            // Because AcirField is modulo the prime field, it does not natively track
66            // "negative" numbers as any value is already reduced modulo the prime.
67            //
68            // Therefore, we do not check for negative inputs here. It is the responsibility
69            // of the code generator to ensure that operands for IntegerDiv are valid unsigned integers.
70            // The only runtime error we check for is division by zero.
71            if b.is_zero() {
72                return Err(BrilligArithmeticError::DivisionByZero);
73            } else {
74                let a_big = BigUint::from_bytes_be(&a.to_be_bytes());
75                let b_big = BigUint::from_bytes_be(&b.to_be_bytes());
76
77                let result = a_big / b_big;
78                MemoryValue::new_field(F::from_be_bytes_reduce(&result.to_bytes_be()))
79            }
80        }
81        BinaryFieldOp::Equals => (a == b).into(),
82        BinaryFieldOp::LessThan => (a < b).into(),
83        BinaryFieldOp::LessThanEquals => (a <= b).into(),
84    })
85}
86
87/// Evaluate a binary operation on two unsigned big integers with a given bit size.
88pub(crate) fn evaluate_binary_int_op<F: AcirField>(
89    op: &BinaryIntOp,
90    lhs: MemoryValue<F>,
91    rhs: MemoryValue<F>,
92    bit_size: IntegerBitSize,
93) -> Result<MemoryValue<F>, BrilligArithmeticError> {
94    match op {
95        BinaryIntOp::Add
96        | BinaryIntOp::Sub
97        | BinaryIntOp::Mul
98        | BinaryIntOp::Div
99        | BinaryIntOp::And
100        | BinaryIntOp::Or
101        | BinaryIntOp::Xor => match (lhs, rhs, bit_size) {
102            (MemoryValue::U1(lhs), MemoryValue::U1(rhs), IntegerBitSize::U1) => {
103                evaluate_binary_int_op_u1(op, lhs, rhs).map(MemoryValue::U1)
104            }
105            (MemoryValue::U8(lhs), MemoryValue::U8(rhs), IntegerBitSize::U8) => {
106                evaluate_binary_int_op_arith(op, lhs, rhs).map(MemoryValue::U8)
107            }
108            (MemoryValue::U16(lhs), MemoryValue::U16(rhs), IntegerBitSize::U16) => {
109                evaluate_binary_int_op_arith(op, lhs, rhs).map(MemoryValue::U16)
110            }
111            (MemoryValue::U32(lhs), MemoryValue::U32(rhs), IntegerBitSize::U32) => {
112                evaluate_binary_int_op_arith(op, lhs, rhs).map(MemoryValue::U32)
113            }
114            (MemoryValue::U64(lhs), MemoryValue::U64(rhs), IntegerBitSize::U64) => {
115                evaluate_binary_int_op_arith(op, lhs, rhs).map(MemoryValue::U64)
116            }
117            (MemoryValue::U128(lhs), MemoryValue::U128(rhs), IntegerBitSize::U128) => {
118                evaluate_binary_int_op_arith(op, lhs, rhs).map(MemoryValue::U128)
119            }
120            (lhs, _, _) if lhs.bit_size() != BitSize::Integer(bit_size) => {
121                Err(BrilligArithmeticError::MismatchedLhsBitSize {
122                    lhs_bit_size: lhs.bit_size().to_u32::<F>(),
123                    op_bit_size: bit_size.into(),
124                })
125            }
126            (_, rhs, _) if rhs.bit_size() != BitSize::Integer(bit_size) => {
127                Err(BrilligArithmeticError::MismatchedRhsBitSize {
128                    rhs_bit_size: rhs.bit_size().to_u32::<F>(),
129                    op_bit_size: bit_size.into(),
130                })
131            }
132            _ => unreachable!("Invalid arguments are covered by the two arms above."),
133        },
134
135        BinaryIntOp::Equals | BinaryIntOp::LessThan | BinaryIntOp::LessThanEquals => {
136            match (lhs, rhs, bit_size) {
137                (MemoryValue::U1(lhs), MemoryValue::U1(rhs), IntegerBitSize::U1) => {
138                    Ok(MemoryValue::U1(evaluate_binary_int_op_cmp(op, lhs, rhs)))
139                }
140                (MemoryValue::U8(lhs), MemoryValue::U8(rhs), IntegerBitSize::U8) => {
141                    Ok(MemoryValue::U1(evaluate_binary_int_op_cmp(op, lhs, rhs)))
142                }
143                (MemoryValue::U16(lhs), MemoryValue::U16(rhs), IntegerBitSize::U16) => {
144                    Ok(MemoryValue::U1(evaluate_binary_int_op_cmp(op, lhs, rhs)))
145                }
146                (MemoryValue::U32(lhs), MemoryValue::U32(rhs), IntegerBitSize::U32) => {
147                    Ok(MemoryValue::U1(evaluate_binary_int_op_cmp(op, lhs, rhs)))
148                }
149                (MemoryValue::U64(lhs), MemoryValue::U64(rhs), IntegerBitSize::U64) => {
150                    Ok(MemoryValue::U1(evaluate_binary_int_op_cmp(op, lhs, rhs)))
151                }
152                (MemoryValue::U128(lhs), MemoryValue::U128(rhs), IntegerBitSize::U128) => {
153                    Ok(MemoryValue::U1(evaluate_binary_int_op_cmp(op, lhs, rhs)))
154                }
155                (lhs, _, _) if lhs.bit_size() != BitSize::Integer(bit_size) => {
156                    Err(BrilligArithmeticError::MismatchedLhsBitSize {
157                        lhs_bit_size: lhs.bit_size().to_u32::<F>(),
158                        op_bit_size: bit_size.into(),
159                    })
160                }
161                (_, rhs, _) if rhs.bit_size() != BitSize::Integer(bit_size) => {
162                    Err(BrilligArithmeticError::MismatchedRhsBitSize {
163                        rhs_bit_size: rhs.bit_size().to_u32::<F>(),
164                        op_bit_size: bit_size.into(),
165                    })
166                }
167                _ => unreachable!("Invalid arguments are covered by the two arms above."),
168            }
169        }
170
171        BinaryIntOp::Shl | BinaryIntOp::Shr => match (lhs, rhs, bit_size) {
172            (MemoryValue::U1(lhs), MemoryValue::U1(rhs), IntegerBitSize::U1) => {
173                if rhs {
174                    Ok(MemoryValue::U1(false))
175                } else {
176                    Ok(MemoryValue::U1(lhs))
177                }
178            }
179            (MemoryValue::U8(lhs), MemoryValue::U8(rhs), IntegerBitSize::U8) => {
180                Ok(MemoryValue::U8(evaluate_binary_int_op_shifts(op, lhs, rhs)?))
181            }
182            (MemoryValue::U16(lhs), MemoryValue::U16(rhs), IntegerBitSize::U16) => {
183                Ok(MemoryValue::U16(evaluate_binary_int_op_shifts(op, lhs, rhs)?))
184            }
185            (MemoryValue::U32(lhs), MemoryValue::U32(rhs), IntegerBitSize::U32) => {
186                Ok(MemoryValue::U32(evaluate_binary_int_op_shifts(op, lhs, rhs)?))
187            }
188            (MemoryValue::U64(lhs), MemoryValue::U64(rhs), IntegerBitSize::U64) => {
189                Ok(MemoryValue::U64(evaluate_binary_int_op_shifts(op, lhs, rhs)?))
190            }
191            (MemoryValue::U128(lhs), MemoryValue::U128(rhs), IntegerBitSize::U128) => {
192                Ok(MemoryValue::U128(evaluate_binary_int_op_shifts(op, lhs, rhs)?))
193            }
194            (lhs, _, _) if lhs.bit_size() != BitSize::Integer(bit_size) => {
195                Err(BrilligArithmeticError::MismatchedLhsBitSize {
196                    lhs_bit_size: lhs.bit_size().to_u32::<F>(),
197                    op_bit_size: bit_size.into(),
198                })
199            }
200            (_, rhs, _) if rhs.bit_size() != BitSize::Integer(bit_size) => {
201                Err(BrilligArithmeticError::MismatchedRhsBitSize {
202                    rhs_bit_size: rhs.bit_size().to_u32::<F>(),
203                    op_bit_size: bit_size.into(),
204                })
205            }
206            _ => unreachable!("Invalid arguments are covered by the two arms above."),
207        },
208    }
209}
210
211/// Evaluates binary operations on 1-bit unsigned integers (booleans).
212///
213/// # Returns
214/// - Ok(result) if successful.
215/// - Err([BrilligArithmeticError::DivisionByZero]) if division by zero occurs.
216///
217/// # Panics
218/// If an operation other than `Add`, `Sub`, `Mul`, `Div`, `And`, `Or`, `Xor`, `Equals`, `LessThan`,
219/// or `LessThanEquals` is supplied as an argument.
220fn evaluate_binary_int_op_u1(
221    op: &BinaryIntOp,
222    lhs: bool,
223    rhs: bool,
224) -> Result<bool, BrilligArithmeticError> {
225    let result = match op {
226        BinaryIntOp::Equals => lhs == rhs,
227        BinaryIntOp::LessThan => !lhs & rhs,
228        BinaryIntOp::LessThanEquals => lhs <= rhs,
229        BinaryIntOp::And | BinaryIntOp::Mul => lhs & rhs,
230        BinaryIntOp::Or => lhs | rhs,
231        BinaryIntOp::Xor | BinaryIntOp::Add | BinaryIntOp::Sub => lhs ^ rhs,
232        BinaryIntOp::Div => {
233            if !rhs {
234                return Err(BrilligArithmeticError::DivisionByZero);
235            } else {
236                lhs
237            }
238        }
239        _ => unreachable!("Operator not handled by this function: {op:?}"),
240    };
241    Ok(result)
242}
243
244/// Evaluates comparison operations (`Equals`, `LessThan`, `LessThanEquals`)
245/// between two values of an ordered type (e.g., fields are unordered).
246///
247/// # Panics
248/// If an unsupported operator is provided (i.e., not `Equals`, `LessThan`, or `LessThanEquals`).
249fn evaluate_binary_int_op_cmp<T: Ord + PartialEq>(op: &BinaryIntOp, lhs: T, rhs: T) -> bool {
250    match op {
251        BinaryIntOp::Equals => lhs == rhs,
252        BinaryIntOp::LessThan => lhs < rhs,
253        BinaryIntOp::LessThanEquals => lhs <= rhs,
254        _ => unreachable!("Operator not handled by this function: {op:?}"),
255    }
256}
257
258/// Evaluates shift operations (`Shl`, `Shr`) for unsigned integers.
259/// Ensures that shifting beyond the type width returns zero.
260///
261/// # Returns
262/// - Ok(result)
263///
264/// # Panics
265/// If an unsupported operator is provided (i.e., not `Shl` or `Shr`).
266fn evaluate_binary_int_op_shifts<T: ToPrimitive + Zero + Shl<Output = T> + Shr<Output = T>>(
267    op: &BinaryIntOp,
268    lhs: T,
269    rhs: T,
270) -> Result<T, BrilligArithmeticError> {
271    let bit_size = (size_of::<T>() * 8) as u128;
272    let rhs_val = rhs.to_u128().unwrap();
273    if rhs_val >= bit_size {
274        return Ok(T::zero());
275    }
276    match op {
277        BinaryIntOp::Shl => Ok(lhs << rhs),
278        BinaryIntOp::Shr => Ok(lhs >> rhs),
279        _ => unreachable!("Operator not handled by this function: {op:?}"),
280    }
281}
282
283/// Evaluates arithmetic or bitwise operations on unsigned integer types,
284/// using wrapping arithmetic for [add][BinaryIntOp::Add], [sub][BinaryIntOp::Sub], and [mul][BinaryIntOp::Mul].
285///
286/// # Returns
287/// - Ok(result) if successful.
288/// - Err([BrilligArithmeticError::DivisionByZero]) if division by zero occurs.
289///
290/// # Panics
291/// If there an operation other than Add, Sub, Mul, Div, And, Or, Xor is supplied as an argument.
292fn evaluate_binary_int_op_arith<
293    T: WrappingAdd
294        + WrappingSub
295        + WrappingMul
296        + CheckedDiv
297        + BitAnd<Output = T>
298        + BitOr<Output = T>
299        + BitXor<Output = T>,
300>(
301    op: &BinaryIntOp,
302    lhs: T,
303    rhs: T,
304) -> Result<T, BrilligArithmeticError> {
305    let result = match op {
306        BinaryIntOp::Add => lhs.wrapping_add(&rhs),
307        BinaryIntOp::Sub => lhs.wrapping_sub(&rhs),
308        BinaryIntOp::Mul => lhs.wrapping_mul(&rhs),
309        BinaryIntOp::Div => lhs.checked_div(&rhs).ok_or(BrilligArithmeticError::DivisionByZero)?,
310        BinaryIntOp::And => lhs & rhs,
311        BinaryIntOp::Or => lhs | rhs,
312        BinaryIntOp::Xor => lhs ^ rhs,
313        _ => unreachable!("Operator not handled by this function: {op:?}"),
314    };
315    Ok(result)
316}
317
318#[cfg(test)]
319mod int_ops {
320    use super::*;
321    use acir::{AcirField, FieldElement};
322
    /// A single integer test case: two operands and the result the operation must produce.
    struct TestParams {
        /// Left-hand operand.
        a: u128,
        /// Right-hand operand.
        b: u128,
        /// Expected outcome of the operation, widened to `u128`.
        result: u128,
    }
328
329    fn evaluate_u128(op: &BinaryIntOp, a: u128, b: u128, bit_size: IntegerBitSize) -> u128 {
330        let result_value: MemoryValue<FieldElement> = evaluate_binary_int_op(
331            op,
332            MemoryValue::new_integer(a, bit_size),
333            MemoryValue::new_integer(b, bit_size),
334            bit_size,
335        )
336        .unwrap();
337        // Convert back to u128
338        result_value.to_field().to_u128()
339    }
340
341    fn to_negative(a: u128, bit_size: IntegerBitSize) -> u128 {
342        assert!(a > 0);
343        if bit_size == IntegerBitSize::U128 {
344            0_u128.wrapping_sub(a)
345        } else {
346            let two_pow = 2_u128.pow(bit_size.into());
347            two_pow - a
348        }
349    }
350
351    fn evaluate_int_ops(test_params: Vec<TestParams>, op: BinaryIntOp, bit_size: IntegerBitSize) {
352        for test in test_params {
353            assert_eq!(evaluate_u128(&op, test.a, test.b, bit_size), test.result);
354        }
355    }
356
357    #[test]
358    fn add_test() {
359        let bit_size = IntegerBitSize::U8;
360
361        let test_ops = vec![
362            TestParams { a: 50, b: 100, result: 150 },
363            TestParams { a: 250, b: 10, result: 4 },
364            TestParams { a: 5, b: to_negative(3, bit_size), result: 2 },
365            TestParams { a: to_negative(3, bit_size), b: 1, result: to_negative(2, bit_size) },
366            TestParams { a: 5, b: to_negative(6, bit_size), result: to_negative(1, bit_size) },
367        ];
368        evaluate_int_ops(test_ops, BinaryIntOp::Add, bit_size);
369
370        let bit_size = IntegerBitSize::U128;
371        let test_ops = vec![
372            TestParams { a: 5, b: to_negative(3, bit_size), result: 2 },
373            TestParams { a: to_negative(3, bit_size), b: 1, result: to_negative(2, bit_size) },
374        ];
375
376        evaluate_int_ops(test_ops, BinaryIntOp::Add, bit_size);
377
378        // Mismatched bit sizes should error
379        assert_eq!(
380            evaluate_binary_int_op(
381                &BinaryIntOp::Add,
382                MemoryValue::<FieldElement>::U8(1),
383                MemoryValue::<FieldElement>::U16(2),
384                IntegerBitSize::U8
385            ),
386            Err(BrilligArithmeticError::MismatchedRhsBitSize { rhs_bit_size: 16, op_bit_size: 8 })
387        );
388        assert_eq!(
389            evaluate_binary_int_op(
390                &BinaryIntOp::Add,
391                MemoryValue::<FieldElement>::U16(2),
392                MemoryValue::<FieldElement>::U8(1),
393                IntegerBitSize::U8
394            ),
395            Err(BrilligArithmeticError::MismatchedLhsBitSize { lhs_bit_size: 16, op_bit_size: 8 })
396        );
397    }
398
399    #[test]
400    fn sub_test() {
401        let bit_size = IntegerBitSize::U8;
402
403        let test_ops = vec![
404            TestParams { a: 50, b: 30, result: 20 },
405            TestParams { a: 5, b: 10, result: to_negative(5, bit_size) },
406            TestParams { a: 5, b: to_negative(3, bit_size), result: 8 },
407            TestParams { a: to_negative(3, bit_size), b: 2, result: to_negative(5, bit_size) },
408            TestParams { a: 254, b: to_negative(3, bit_size), result: 1 },
409        ];
410        evaluate_int_ops(test_ops, BinaryIntOp::Sub, bit_size);
411
412        let bit_size = IntegerBitSize::U128;
413
414        let test_ops = vec![
415            TestParams { a: 5, b: 10, result: to_negative(5, bit_size) },
416            TestParams { a: to_negative(3, bit_size), b: 2, result: to_negative(5, bit_size) },
417        ];
418        evaluate_int_ops(test_ops, BinaryIntOp::Sub, bit_size);
419
420        // Mismatched bit sizes should error
421        assert_eq!(
422            evaluate_binary_int_op(
423                &BinaryIntOp::Sub,
424                MemoryValue::<FieldElement>::U8(1),
425                MemoryValue::<FieldElement>::U16(1),
426                IntegerBitSize::U8
427            ),
428            Err(BrilligArithmeticError::MismatchedRhsBitSize { rhs_bit_size: 16, op_bit_size: 8 })
429        );
430        assert_eq!(
431            evaluate_binary_int_op(
432                &BinaryIntOp::Sub,
433                MemoryValue::<FieldElement>::U16(1),
434                MemoryValue::<FieldElement>::U8(1),
435                IntegerBitSize::U8
436            ),
437            Err(BrilligArithmeticError::MismatchedLhsBitSize { lhs_bit_size: 16, op_bit_size: 8 })
438        );
439    }
440
441    #[test]
442    fn mul_test() {
443        let bit_size = IntegerBitSize::U8;
444
445        let test_ops = vec![
446            TestParams { a: 5, b: 3, result: 15 },
447            TestParams { a: 5, b: 100, result: 244 },
448            TestParams { a: to_negative(1, bit_size), b: to_negative(5, bit_size), result: 5 },
449            TestParams { a: to_negative(1, bit_size), b: 5, result: to_negative(5, bit_size) },
450            TestParams { a: to_negative(2, bit_size), b: 7, result: to_negative(14, bit_size) },
451        ];
452
453        evaluate_int_ops(test_ops, BinaryIntOp::Mul, bit_size);
454
455        let bit_size = IntegerBitSize::U64;
456        let a = 2_u128.pow(bit_size.into()) - 1;
457        let b = 3;
458
459        // ( 2**(n-1) - 1 ) * 3 = 2*2**(n-1) - 2 + (2**(n-1) - 1) => wraps to (2**(n-1) - 1) - 2
460        assert_eq!(evaluate_u128(&BinaryIntOp::Mul, a, b, bit_size), a - 2);
461
462        let bit_size = IntegerBitSize::U128;
463
464        let test_ops = vec![
465            TestParams { a: to_negative(1, bit_size), b: to_negative(5, bit_size), result: 5 },
466            TestParams { a: to_negative(1, bit_size), b: 5, result: to_negative(5, bit_size) },
467            TestParams { a: to_negative(2, bit_size), b: 7, result: to_negative(14, bit_size) },
468        ];
469
470        evaluate_int_ops(test_ops, BinaryIntOp::Mul, bit_size);
471
472        // Mismatched bit sizes should error
473        assert_eq!(
474            evaluate_binary_int_op(
475                &BinaryIntOp::Mul,
476                MemoryValue::<FieldElement>::U8(1),
477                MemoryValue::<FieldElement>::U16(1),
478                IntegerBitSize::U8
479            ),
480            Err(BrilligArithmeticError::MismatchedRhsBitSize { rhs_bit_size: 16, op_bit_size: 8 })
481        );
482        assert_eq!(
483            evaluate_binary_int_op(
484                &BinaryIntOp::Mul,
485                MemoryValue::<FieldElement>::U16(1),
486                MemoryValue::<FieldElement>::U8(1),
487                IntegerBitSize::U8
488            ),
489            Err(BrilligArithmeticError::MismatchedLhsBitSize { lhs_bit_size: 16, op_bit_size: 8 })
490        );
491    }
492
493    #[test]
494    fn div_test() {
495        let bit_size = IntegerBitSize::U8;
496
497        let test_ops =
498            vec![TestParams { a: 5, b: 3, result: 1 }, TestParams { a: 5, b: 10, result: 0 }];
499
500        evaluate_int_ops(test_ops, BinaryIntOp::Div, bit_size);
501
502        // Mismatched bit sizes should error
503        assert_eq!(
504            evaluate_binary_int_op(
505                &BinaryIntOp::Div,
506                MemoryValue::<FieldElement>::U8(1),
507                MemoryValue::<FieldElement>::U16(1),
508                IntegerBitSize::U8
509            ),
510            Err(BrilligArithmeticError::MismatchedRhsBitSize { rhs_bit_size: 16, op_bit_size: 8 })
511        );
512        assert_eq!(
513            evaluate_binary_int_op(
514                &BinaryIntOp::Div,
515                MemoryValue::<FieldElement>::U16(1),
516                MemoryValue::<FieldElement>::U8(1),
517                IntegerBitSize::U8
518            ),
519            Err(BrilligArithmeticError::MismatchedLhsBitSize { lhs_bit_size: 16, op_bit_size: 8 })
520        );
521
522        // Division by zero should error
523        assert_eq!(
524            evaluate_binary_int_op(
525                &BinaryIntOp::Div,
526                MemoryValue::<FieldElement>::U8(1),
527                MemoryValue::<FieldElement>::U8(0),
528                IntegerBitSize::U8
529            ),
530            Err(BrilligArithmeticError::DivisionByZero)
531        );
532    }
533
534    #[test]
535    fn shl_test() {
536        let bit_size = IntegerBitSize::U8;
537
538        let test_ops =
539            vec![TestParams { a: 1, b: 7, result: 128 }, TestParams { a: 5, b: 7, result: 128 }];
540
541        evaluate_int_ops(test_ops, BinaryIntOp::Shl, bit_size);
542        // Shifting more than bit width returns zero
543        assert_eq!(
544            evaluate_binary_int_op(
545                &BinaryIntOp::Shl,
546                MemoryValue::<FieldElement>::U8(1u8),
547                MemoryValue::<FieldElement>::U8(8u8),
548                IntegerBitSize::U8
549            ),
550            Ok(MemoryValue::<FieldElement>::U8(0u8))
551        );
552        // Both LHS and RHS has to match the operation bit size.
553        assert_eq!(
554            evaluate_binary_int_op(
555                &BinaryIntOp::Shr,
556                MemoryValue::<FieldElement>::U16(1),
557                MemoryValue::<FieldElement>::U8(1),
558                IntegerBitSize::U8
559            ),
560            Err(BrilligArithmeticError::MismatchedLhsBitSize { lhs_bit_size: 16, op_bit_size: 8 })
561        );
562        assert_eq!(
563            evaluate_binary_int_op(
564                &BinaryIntOp::Shr,
565                MemoryValue::<FieldElement>::U8(1),
566                MemoryValue::<FieldElement>::U16(1),
567                IntegerBitSize::U8
568            ),
569            Err(BrilligArithmeticError::MismatchedRhsBitSize { rhs_bit_size: 16, op_bit_size: 8 })
570        );
571    }
572
573    #[test]
574    fn shr_test() {
575        let bit_size = IntegerBitSize::U8;
576
577        let test_ops =
578            vec![TestParams { a: 1, b: 0, result: 1 }, TestParams { a: 5, b: 1, result: 2 }];
579
580        evaluate_int_ops(test_ops, BinaryIntOp::Shr, bit_size);
581        // Shifting more than bit width returns zero
582        assert_eq!(
583            evaluate_binary_int_op(
584                &BinaryIntOp::Shr,
585                MemoryValue::<FieldElement>::U8(1u8),
586                MemoryValue::<FieldElement>::U8(8u8),
587                IntegerBitSize::U8
588            ),
589            Ok(MemoryValue::<FieldElement>::U8(0u8))
590        );
591        // Both LHS and RHS has to match the operation bit size.
592        assert_eq!(
593            evaluate_binary_int_op(
594                &BinaryIntOp::Shr,
595                MemoryValue::<FieldElement>::U16(1),
596                MemoryValue::<FieldElement>::U8(1),
597                IntegerBitSize::U8
598            ),
599            Err(BrilligArithmeticError::MismatchedLhsBitSize { lhs_bit_size: 16, op_bit_size: 8 })
600        );
601        assert_eq!(
602            evaluate_binary_int_op(
603                &BinaryIntOp::Shr,
604                MemoryValue::<FieldElement>::U8(1),
605                MemoryValue::<FieldElement>::U16(1),
606                IntegerBitSize::U8
607            ),
608            Err(BrilligArithmeticError::MismatchedRhsBitSize { rhs_bit_size: 16, op_bit_size: 8 })
609        );
610    }
611
612    /// Reference implementation: compute `lhs op rhs` using u128 arithmetic,
613    /// then truncate to the given bit width.
614    fn reference_shift(op: &BinaryIntOp, lhs: u128, rhs: u128, bit_width: u32) -> u128 {
615        let mask = if bit_width == 128 { u128::MAX } else { (1u128 << bit_width) - 1 };
616        let lhs = lhs & mask;
617        let rhs = rhs & mask;
618        if rhs >= u128::from(bit_width) {
619            return 0;
620        }
621        let result = match op {
622            BinaryIntOp::Shl => lhs << rhs as u32,
623            BinaryIntOp::Shr => lhs >> rhs as u32,
624            _ => unreachable!(),
625        };
626        result & mask
627    }
628
629    proptest::proptest! {
630        #[test]
631        fn shift_u8_fuzz(lhs in 0u128..=u128::from(u8::MAX), rhs in 0u128..=u128::from(u8::MAX)) {
632            let bit_size = IntegerBitSize::U8;
633            for op in [BinaryIntOp::Shl, BinaryIntOp::Shr] {
634                let actual = evaluate_u128(&op, lhs, rhs, bit_size);
635                let expected = reference_shift(&op, lhs, rhs, 8);
636                proptest::prop_assert_eq!(actual, expected);
637            }
638        }
639
640        #[test]
641        fn shift_u16_fuzz(lhs in 0u128..=u128::from(u16::MAX), rhs in 0u128..=u128::from(u16::MAX)) {
642            let bit_size = IntegerBitSize::U16;
643            for op in [BinaryIntOp::Shl, BinaryIntOp::Shr] {
644                let actual = evaluate_u128(&op, lhs, rhs, bit_size);
645                let expected = reference_shift(&op, lhs, rhs, 16);
646                proptest::prop_assert_eq!(actual, expected);
647            }
648        }
649
650        #[test]
651        fn shift_u32_fuzz(lhs in 0u128..=u128::from(u32::MAX), rhs in 0u128..=u128::from(u32::MAX)) {
652            let bit_size = IntegerBitSize::U32;
653            for op in [BinaryIntOp::Shl, BinaryIntOp::Shr] {
654                let actual = evaluate_u128(&op, lhs, rhs, bit_size);
655                let expected = reference_shift(&op, lhs, rhs, 32);
656                proptest::prop_assert_eq!(actual, expected);
657            }
658        }
659
660        #[test]
661        fn shift_u64_fuzz(lhs in 0u128..=u128::from(u64::MAX), rhs in 0u128..=u128::from(u64::MAX)) {
662            let bit_size = IntegerBitSize::U64;
663            for op in [BinaryIntOp::Shl, BinaryIntOp::Shr] {
664                let actual = evaluate_u128(&op, lhs, rhs, bit_size);
665                let expected = reference_shift(&op, lhs, rhs, 64);
666                proptest::prop_assert_eq!(actual, expected);
667            }
668        }
669
670        #[test]
671        fn shift_u128_fuzz(lhs: u128, rhs: u128) {
672            let bit_size = IntegerBitSize::U128;
673            for op in [BinaryIntOp::Shl, BinaryIntOp::Shr] {
674                let actual = evaluate_u128(&op, lhs, rhs, bit_size);
675                let expected = reference_shift(&op, lhs, rhs, 128);
676                proptest::prop_assert_eq!(actual, expected);
677            }
678        }
679    }
680
681    #[test]
682    fn comparison_ops_test() {
683        let bit_size = IntegerBitSize::U8;
684
685        // Equals
686        let test_ops = vec![
687            TestParams { a: 5, b: 5, result: 1 },
688            TestParams { a: 10, b: 5, result: 0 },
689            TestParams { a: 0, b: 0, result: 1 },
690        ];
691        evaluate_int_ops(test_ops, BinaryIntOp::Equals, bit_size);
692
693        // LessThan
694        let test_ops = vec![
695            TestParams { a: 4, b: 5, result: 1 },
696            TestParams { a: 5, b: 4, result: 0 },
697            TestParams { a: 5, b: 5, result: 0 },
698        ];
699        evaluate_int_ops(test_ops, BinaryIntOp::LessThan, bit_size);
700
701        // LessThanEquals
702        let test_ops = vec![
703            TestParams { a: 4, b: 5, result: 1 },
704            TestParams { a: 5, b: 4, result: 0 },
705            TestParams { a: 5, b: 5, result: 1 },
706        ];
707        evaluate_int_ops(test_ops, BinaryIntOp::LessThanEquals, bit_size);
708
709        // Mismatched bit sizes should error
710        assert_eq!(
711            evaluate_binary_int_op(
712                &BinaryIntOp::Equals,
713                MemoryValue::<FieldElement>::U8(1),
714                MemoryValue::<FieldElement>::U16(1),
715                IntegerBitSize::U8
716            ),
717            Err(BrilligArithmeticError::MismatchedRhsBitSize { rhs_bit_size: 16, op_bit_size: 8 })
718        );
719        assert_eq!(
720            evaluate_binary_int_op(
721                &BinaryIntOp::LessThan,
722                MemoryValue::<FieldElement>::U16(1),
723                MemoryValue::<FieldElement>::U8(1),
724                IntegerBitSize::U8
725            ),
726            Err(BrilligArithmeticError::MismatchedLhsBitSize { lhs_bit_size: 16, op_bit_size: 8 })
727        );
728        assert_eq!(
729            evaluate_binary_int_op(
730                &BinaryIntOp::LessThanEquals,
731                MemoryValue::<FieldElement>::U16(1),
732                MemoryValue::<FieldElement>::U8(1),
733                IntegerBitSize::U8
734            ),
735            Err(BrilligArithmeticError::MismatchedLhsBitSize { lhs_bit_size: 16, op_bit_size: 8 })
736        );
737    }
738}
739
740#[cfg(test)]
741mod field_ops {
742    use super::*;
743    use acir::{AcirField, FieldElement};
744
    /// A single field test case: two operands and the result the operation must produce.
    struct TestParams {
        /// Left-hand operand.
        a: FieldElement,
        /// Right-hand operand.
        b: FieldElement,
        /// Expected outcome of the operation.
        result: FieldElement,
    }
750
751    fn evaluate_field_u128(op: &BinaryFieldOp, a: FieldElement, b: FieldElement) -> FieldElement {
752        let result_value: MemoryValue<FieldElement> =
753            evaluate_binary_field_op(op, MemoryValue::new_field(a), MemoryValue::new_field(b))
754                .unwrap();
755        // Convert back to FieldElement
756        result_value.to_field()
757    }
758
759    fn evaluate_field_ops(test_params: Vec<TestParams>, op: BinaryFieldOp) {
760        for test in test_params {
761            assert_eq!(evaluate_field_u128(&op, test.a, test.b), test.result);
762        }
763    }
764
765    #[test]
766    fn add_test() {
767        let test_ops = vec![
768            TestParams { a: 1u32.into(), b: 2u32.into(), result: 3u32.into() },
769            TestParams { a: 5u32.into(), b: 10u32.into(), result: 15u32.into() },
770            TestParams { a: 250u32.into(), b: 10u32.into(), result: 260u32.into() },
771        ];
772        evaluate_field_ops(test_ops, BinaryFieldOp::Add);
773
774        // Mismatched bit sizes
775        assert_eq!(
776            evaluate_binary_field_op(
777                &BinaryFieldOp::Add,
778                MemoryValue::new_field(FieldElement::from(1u32)),
779                MemoryValue::<FieldElement>::U16(2),
780            ),
781            Err(BrilligArithmeticError::MismatchedRhsBitSize {
782                rhs_bit_size: 16,
783                op_bit_size: 254
784            })
785        );
786
787        assert_eq!(
788            evaluate_binary_field_op(
789                &BinaryFieldOp::Add,
790                MemoryValue::<FieldElement>::U16(1),
791                MemoryValue::new_field(FieldElement::from(2u32)),
792            ),
793            Err(BrilligArithmeticError::MismatchedLhsBitSize {
794                lhs_bit_size: 16,
795                op_bit_size: 254
796            })
797        );
798    }
799
800    #[test]
801    fn sub_test() {
802        let test_ops = vec![
803            TestParams { a: 5u32.into(), b: 3u32.into(), result: 2u32.into() },
804            TestParams { a: 2u32.into(), b: 10u32.into(), result: FieldElement::from(-8_i128) },
805        ];
806        evaluate_field_ops(test_ops, BinaryFieldOp::Sub);
807
808        // Mismatched bit sizes
809        assert_eq!(
810            evaluate_binary_field_op(
811                &BinaryFieldOp::Sub,
812                MemoryValue::new_field(FieldElement::from(1u32)),
813                MemoryValue::<FieldElement>::U16(1),
814            ),
815            Err(BrilligArithmeticError::MismatchedRhsBitSize {
816                rhs_bit_size: 16,
817                op_bit_size: 254
818            })
819        );
820
821        assert_eq!(
822            evaluate_binary_field_op(
823                &BinaryFieldOp::Sub,
824                MemoryValue::<FieldElement>::U16(1),
825                MemoryValue::new_field(FieldElement::from(1u32)),
826            ),
827            Err(BrilligArithmeticError::MismatchedLhsBitSize {
828                lhs_bit_size: 16,
829                op_bit_size: 254
830            })
831        );
832    }
833
834    #[test]
835    fn mul_test() {
836        let test_ops = vec![
837            TestParams { a: 2u32.into(), b: 3u32.into(), result: 6u32.into() },
838            TestParams { a: 10u32.into(), b: 25u32.into(), result: 250u32.into() },
839        ];
840        evaluate_field_ops(test_ops, BinaryFieldOp::Mul);
841
842        // Mismatched bit sizes
843        assert_eq!(
844            evaluate_binary_field_op(
845                &BinaryFieldOp::Mul,
846                MemoryValue::new_field(FieldElement::from(1u32)),
847                MemoryValue::<FieldElement>::U16(1),
848            ),
849            Err(BrilligArithmeticError::MismatchedRhsBitSize {
850                rhs_bit_size: 16,
851                op_bit_size: 254
852            })
853        );
854
855        assert_eq!(
856            evaluate_binary_field_op(
857                &BinaryFieldOp::Mul,
858                MemoryValue::<FieldElement>::U16(1),
859                MemoryValue::new_field(FieldElement::from(1u32)),
860            ),
861            Err(BrilligArithmeticError::MismatchedLhsBitSize {
862                lhs_bit_size: 16,
863                op_bit_size: 254
864            })
865        );
866    }
867
868    #[test]
869    fn div_test() {
870        let test_ops = vec![
871            TestParams { a: 10u32.into(), b: 2u32.into(), result: 5u32.into() },
872            TestParams { a: 9u32.into(), b: 3u32.into(), result: 3u32.into() },
873            TestParams {
874                a: 10u32.into(),
875                b: FieldElement::from(-1_i128),
876                result: FieldElement::from(-10_i128),
877            },
878            TestParams { a: 10u32.into(), b: 1u32.into(), result: 10u32.into() },
879            // Field division is a * 1/b. The inverse of 20 is 7660885005143746327786242010840046280991927540145612020294371465301532973466
880            TestParams {
881                a: 10u32.into(),
882                b: 20u32.into(),
883                result: FieldElement::try_from_str(
884                    "10944121435919637611123202872628637544274182200208017171849102093287904247809",
885                )
886                .unwrap(),
887            },
888            // The inverse of 7 is 3126891838834182174606629392179610726935480628630862049099743455225115499374.
889            TestParams {
890                a: 100u32.into(),
891                b: 7u32.into(),
892                result: FieldElement::try_from_str(
893                    "6253783677668364349213258784359221453870961257261724098199486910450230998762",
894                )
895                .unwrap(),
896            },
897        ];
898        evaluate_field_ops(test_ops, BinaryFieldOp::Div);
899
900        // Division by zero
901        assert_eq!(
902            evaluate_binary_field_op(
903                &BinaryFieldOp::Div,
904                MemoryValue::new_field(FieldElement::from(1u128)),
905                MemoryValue::new_field(FieldElement::zero()),
906            ),
907            Err(BrilligArithmeticError::DivisionByZero)
908        );
909
910        // Mismatched bit sizes
911        assert_eq!(
912            evaluate_binary_field_op(
913                &BinaryFieldOp::Div,
914                MemoryValue::new_field(FieldElement::from(1u32)),
915                MemoryValue::<FieldElement>::U16(1),
916            ),
917            Err(BrilligArithmeticError::MismatchedRhsBitSize {
918                rhs_bit_size: 16,
919                op_bit_size: 254
920            })
921        );
922
923        assert_eq!(
924            evaluate_binary_field_op(
925                &BinaryFieldOp::Div,
926                MemoryValue::<FieldElement>::U16(1),
927                MemoryValue::new_field(FieldElement::from(1u32)),
928            ),
929            Err(BrilligArithmeticError::MismatchedLhsBitSize {
930                lhs_bit_size: 16,
931                op_bit_size: 254
932            })
933        );
934    }
935
936    #[test]
937    fn integer_div_test() {
938        let test_ops = vec![
939            TestParams { a: 10u32.into(), b: 2u32.into(), result: 5u32.into() },
940            TestParams { a: 9u32.into(), b: 3u32.into(), result: 3u32.into() },
941            // Negative numbers are treated as large unsigned numbers, thus we expect a result of 0 here
942            TestParams { a: 10u32.into(), b: FieldElement::from(-1_i128), result: 0u32.into() },
943            TestParams { a: 10u32.into(), b: 1u32.into(), result: 10u32.into() },
944            TestParams { a: 10u32.into(), b: 20u32.into(), result: 0u32.into() },
945            // 100 / 7 == 14 with a remainder of 2. The remainder is discarded.
946            TestParams { a: 100u32.into(), b: 7u32.into(), result: 14u32.into() },
947        ];
948        evaluate_field_ops(test_ops, BinaryFieldOp::IntegerDiv);
949
950        // Division by zero should error
951        assert_eq!(
952            evaluate_binary_field_op(
953                &BinaryFieldOp::IntegerDiv,
954                MemoryValue::new_field(FieldElement::from(1u128)),
955                MemoryValue::new_field(FieldElement::zero()),
956            ),
957            Err(BrilligArithmeticError::DivisionByZero)
958        );
959
960        // Mismatched bit sizes should error
961        assert_eq!(
962            evaluate_binary_field_op(
963                &BinaryFieldOp::IntegerDiv,
964                MemoryValue::new_field(FieldElement::from(1u32)),
965                MemoryValue::<FieldElement>::U16(1),
966            ),
967            Err(BrilligArithmeticError::MismatchedRhsBitSize {
968                rhs_bit_size: 16,
969                op_bit_size: 254
970            })
971        );
972
973        assert_eq!(
974            evaluate_binary_field_op(
975                &BinaryFieldOp::IntegerDiv,
976                MemoryValue::<FieldElement>::U16(1),
977                MemoryValue::new_field(FieldElement::from(1u32)),
978            ),
979            Err(BrilligArithmeticError::MismatchedLhsBitSize {
980                lhs_bit_size: 16,
981                op_bit_size: 254
982            })
983        );
984    }
985
986    #[test]
987    fn comparison_ops_test() {
988        // Equals
989        let test_ops = vec![
990            TestParams { a: 5u32.into(), b: 5u32.into(), result: 1u32.into() },
991            TestParams { a: 10u32.into(), b: 5u32.into(), result: 0u32.into() },
992            TestParams { a: 0u32.into(), b: 0u32.into(), result: 1u32.into() },
993        ];
994        evaluate_field_ops(test_ops, BinaryFieldOp::Equals);
995
996        // LessThan
997        let test_ops = vec![
998            TestParams { a: 4u32.into(), b: 5u32.into(), result: 1u32.into() },
999            TestParams { a: 5u32.into(), b: 4u32.into(), result: 0u32.into() },
1000            TestParams { a: 5u32.into(), b: 5u32.into(), result: 0u32.into() },
1001        ];
1002        evaluate_field_ops(test_ops, BinaryFieldOp::LessThan);
1003
1004        // LessThanEquals
1005        let test_ops = vec![
1006            TestParams { a: 4u32.into(), b: 5u32.into(), result: 1u32.into() },
1007            TestParams { a: 5u32.into(), b: 4u32.into(), result: 0u32.into() },
1008            TestParams { a: 5u32.into(), b: 5u32.into(), result: 1u32.into() },
1009        ];
1010        evaluate_field_ops(test_ops, BinaryFieldOp::LessThanEquals);
1011
1012        // Mismatched bit sizes should error
1013        assert_eq!(
1014            evaluate_binary_field_op(
1015                &BinaryFieldOp::Equals,
1016                MemoryValue::new_field(1u32.into()),
1017                MemoryValue::<FieldElement>::U16(1),
1018            ),
1019            Err(BrilligArithmeticError::MismatchedRhsBitSize {
1020                rhs_bit_size: 16,
1021                op_bit_size: 254
1022            })
1023        );
1024
1025        assert_eq!(
1026            evaluate_binary_field_op(
1027                &BinaryFieldOp::LessThan,
1028                MemoryValue::<FieldElement>::U16(1),
1029                MemoryValue::new_field(1u32.into()),
1030            ),
1031            Err(BrilligArithmeticError::MismatchedLhsBitSize {
1032                lhs_bit_size: 16,
1033                op_bit_size: 254
1034            })
1035        );
1036
1037        assert_eq!(
1038            evaluate_binary_field_op(
1039                &BinaryFieldOp::LessThanEquals,
1040                MemoryValue::<FieldElement>::U16(1),
1041                MemoryValue::new_field(1u32.into()),
1042            ),
1043            Err(BrilligArithmeticError::MismatchedLhsBitSize {
1044                lhs_bit_size: 16,
1045                op_bit_size: 254
1046            })
1047        );
1048    }
1049}