1use std::ops::{BitAnd, BitOr, BitXor, Shl, Shr};
4
5use acir::AcirField;
6use acir::brillig::{BinaryFieldOp, BinaryIntOp, BitSize, IntegerBitSize};
7use num_bigint::BigUint;
8use num_traits::{CheckedDiv, ToPrimitive, WrappingAdd, WrappingMul, WrappingSub, Zero};
9
10use crate::memory::{MemoryTypeError, MemoryValue};
11
/// Errors produced while evaluating Brillig binary operations on field or integer operands.
#[derive(Debug, PartialEq, thiserror::Error)]
pub(crate) enum BrilligArithmeticError {
    /// The left-hand operand's bit size differs from the bit size the operation expects.
    #[error("Bit size for lhs {lhs_bit_size} does not match op bit size {op_bit_size}")]
    MismatchedLhsBitSize { lhs_bit_size: u32, op_bit_size: u32 },
    /// The right-hand operand's bit size differs from the bit size the operation expects.
    #[error("Bit size for rhs {rhs_bit_size} does not match op bit size {op_bit_size}")]
    MismatchedRhsBitSize { rhs_bit_size: u32, op_bit_size: u32 },
    /// A division operation was given a zero divisor.
    #[error("Attempted to divide by zero")]
    DivisionByZero,
}
21
22pub(crate) fn evaluate_binary_field_op<F: AcirField>(
24 op: &BinaryFieldOp,
25 lhs: MemoryValue<F>,
26 rhs: MemoryValue<F>,
27) -> Result<MemoryValue<F>, BrilligArithmeticError> {
28 let expect_field = |value: MemoryValue<F>, make_err: fn(u32, u32) -> BrilligArithmeticError| {
29 value.expect_field().map_err(|err| {
30 if let MemoryTypeError::MismatchedBitSize { value_bit_size, expected_bit_size } = err {
31 make_err(value_bit_size, expected_bit_size)
32 } else {
33 unreachable!("MemoryTypeError NotInteger is only produced by to_u128")
34 }
35 })
36 };
37 let a = expect_field(lhs, |vbs, ebs| BrilligArithmeticError::MismatchedLhsBitSize {
38 lhs_bit_size: vbs,
39 op_bit_size: ebs,
40 })?;
41 let b = expect_field(rhs, |vbs, ebs| BrilligArithmeticError::MismatchedRhsBitSize {
42 rhs_bit_size: vbs,
43 op_bit_size: ebs,
44 })?;
45
46 Ok(match op {
47 BinaryFieldOp::Add => MemoryValue::new_field(a + b),
49 BinaryFieldOp::Sub => MemoryValue::new_field(a - b),
50 BinaryFieldOp::Mul => MemoryValue::new_field(a * b),
51 BinaryFieldOp::Div => {
52 if b.is_zero() {
53 return Err(BrilligArithmeticError::DivisionByZero);
54 } else if b.is_one() {
55 MemoryValue::new_field(a)
56 } else if b == -F::one() {
57 MemoryValue::new_field(-a)
58 } else {
59 MemoryValue::new_field(a / b)
60 }
61 }
62 BinaryFieldOp::IntegerDiv => {
63 if b.is_zero() {
72 return Err(BrilligArithmeticError::DivisionByZero);
73 } else {
74 let a_big = BigUint::from_bytes_be(&a.to_be_bytes());
75 let b_big = BigUint::from_bytes_be(&b.to_be_bytes());
76
77 let result = a_big / b_big;
78 MemoryValue::new_field(F::from_be_bytes_reduce(&result.to_bytes_be()))
79 }
80 }
81 BinaryFieldOp::Equals => (a == b).into(),
82 BinaryFieldOp::LessThan => (a < b).into(),
83 BinaryFieldOp::LessThanEquals => (a <= b).into(),
84 })
85}
86
87pub(crate) fn evaluate_binary_int_op<F: AcirField>(
89 op: &BinaryIntOp,
90 lhs: MemoryValue<F>,
91 rhs: MemoryValue<F>,
92 bit_size: IntegerBitSize,
93) -> Result<MemoryValue<F>, BrilligArithmeticError> {
94 match op {
95 BinaryIntOp::Add
96 | BinaryIntOp::Sub
97 | BinaryIntOp::Mul
98 | BinaryIntOp::Div
99 | BinaryIntOp::And
100 | BinaryIntOp::Or
101 | BinaryIntOp::Xor => match (lhs, rhs, bit_size) {
102 (MemoryValue::U1(lhs), MemoryValue::U1(rhs), IntegerBitSize::U1) => {
103 evaluate_binary_int_op_u1(op, lhs, rhs).map(MemoryValue::U1)
104 }
105 (MemoryValue::U8(lhs), MemoryValue::U8(rhs), IntegerBitSize::U8) => {
106 evaluate_binary_int_op_arith(op, lhs, rhs).map(MemoryValue::U8)
107 }
108 (MemoryValue::U16(lhs), MemoryValue::U16(rhs), IntegerBitSize::U16) => {
109 evaluate_binary_int_op_arith(op, lhs, rhs).map(MemoryValue::U16)
110 }
111 (MemoryValue::U32(lhs), MemoryValue::U32(rhs), IntegerBitSize::U32) => {
112 evaluate_binary_int_op_arith(op, lhs, rhs).map(MemoryValue::U32)
113 }
114 (MemoryValue::U64(lhs), MemoryValue::U64(rhs), IntegerBitSize::U64) => {
115 evaluate_binary_int_op_arith(op, lhs, rhs).map(MemoryValue::U64)
116 }
117 (MemoryValue::U128(lhs), MemoryValue::U128(rhs), IntegerBitSize::U128) => {
118 evaluate_binary_int_op_arith(op, lhs, rhs).map(MemoryValue::U128)
119 }
120 (lhs, _, _) if lhs.bit_size() != BitSize::Integer(bit_size) => {
121 Err(BrilligArithmeticError::MismatchedLhsBitSize {
122 lhs_bit_size: lhs.bit_size().to_u32::<F>(),
123 op_bit_size: bit_size.into(),
124 })
125 }
126 (_, rhs, _) if rhs.bit_size() != BitSize::Integer(bit_size) => {
127 Err(BrilligArithmeticError::MismatchedRhsBitSize {
128 rhs_bit_size: rhs.bit_size().to_u32::<F>(),
129 op_bit_size: bit_size.into(),
130 })
131 }
132 _ => unreachable!("Invalid arguments are covered by the two arms above."),
133 },
134
135 BinaryIntOp::Equals | BinaryIntOp::LessThan | BinaryIntOp::LessThanEquals => {
136 match (lhs, rhs, bit_size) {
137 (MemoryValue::U1(lhs), MemoryValue::U1(rhs), IntegerBitSize::U1) => {
138 Ok(MemoryValue::U1(evaluate_binary_int_op_cmp(op, lhs, rhs)))
139 }
140 (MemoryValue::U8(lhs), MemoryValue::U8(rhs), IntegerBitSize::U8) => {
141 Ok(MemoryValue::U1(evaluate_binary_int_op_cmp(op, lhs, rhs)))
142 }
143 (MemoryValue::U16(lhs), MemoryValue::U16(rhs), IntegerBitSize::U16) => {
144 Ok(MemoryValue::U1(evaluate_binary_int_op_cmp(op, lhs, rhs)))
145 }
146 (MemoryValue::U32(lhs), MemoryValue::U32(rhs), IntegerBitSize::U32) => {
147 Ok(MemoryValue::U1(evaluate_binary_int_op_cmp(op, lhs, rhs)))
148 }
149 (MemoryValue::U64(lhs), MemoryValue::U64(rhs), IntegerBitSize::U64) => {
150 Ok(MemoryValue::U1(evaluate_binary_int_op_cmp(op, lhs, rhs)))
151 }
152 (MemoryValue::U128(lhs), MemoryValue::U128(rhs), IntegerBitSize::U128) => {
153 Ok(MemoryValue::U1(evaluate_binary_int_op_cmp(op, lhs, rhs)))
154 }
155 (lhs, _, _) if lhs.bit_size() != BitSize::Integer(bit_size) => {
156 Err(BrilligArithmeticError::MismatchedLhsBitSize {
157 lhs_bit_size: lhs.bit_size().to_u32::<F>(),
158 op_bit_size: bit_size.into(),
159 })
160 }
161 (_, rhs, _) if rhs.bit_size() != BitSize::Integer(bit_size) => {
162 Err(BrilligArithmeticError::MismatchedRhsBitSize {
163 rhs_bit_size: rhs.bit_size().to_u32::<F>(),
164 op_bit_size: bit_size.into(),
165 })
166 }
167 _ => unreachable!("Invalid arguments are covered by the two arms above."),
168 }
169 }
170
171 BinaryIntOp::Shl | BinaryIntOp::Shr => match (lhs, rhs, bit_size) {
172 (MemoryValue::U1(lhs), MemoryValue::U1(rhs), IntegerBitSize::U1) => {
173 if rhs {
174 Ok(MemoryValue::U1(false))
175 } else {
176 Ok(MemoryValue::U1(lhs))
177 }
178 }
179 (MemoryValue::U8(lhs), MemoryValue::U8(rhs), IntegerBitSize::U8) => {
180 Ok(MemoryValue::U8(evaluate_binary_int_op_shifts(op, lhs, rhs)?))
181 }
182 (MemoryValue::U16(lhs), MemoryValue::U16(rhs), IntegerBitSize::U16) => {
183 Ok(MemoryValue::U16(evaluate_binary_int_op_shifts(op, lhs, rhs)?))
184 }
185 (MemoryValue::U32(lhs), MemoryValue::U32(rhs), IntegerBitSize::U32) => {
186 Ok(MemoryValue::U32(evaluate_binary_int_op_shifts(op, lhs, rhs)?))
187 }
188 (MemoryValue::U64(lhs), MemoryValue::U64(rhs), IntegerBitSize::U64) => {
189 Ok(MemoryValue::U64(evaluate_binary_int_op_shifts(op, lhs, rhs)?))
190 }
191 (MemoryValue::U128(lhs), MemoryValue::U128(rhs), IntegerBitSize::U128) => {
192 Ok(MemoryValue::U128(evaluate_binary_int_op_shifts(op, lhs, rhs)?))
193 }
194 _ => Err(BrilligArithmeticError::MismatchedLhsBitSize {
195 lhs_bit_size: lhs.bit_size().to_u32::<F>(),
196 op_bit_size: bit_size.into(),
197 }),
198 },
199 }
200}
201
202fn evaluate_binary_int_op_u1(
212 op: &BinaryIntOp,
213 lhs: bool,
214 rhs: bool,
215) -> Result<bool, BrilligArithmeticError> {
216 let result = match op {
217 BinaryIntOp::Equals => lhs == rhs,
218 BinaryIntOp::LessThan => !lhs & rhs,
219 BinaryIntOp::LessThanEquals => lhs <= rhs,
220 BinaryIntOp::And | BinaryIntOp::Mul => lhs & rhs,
221 BinaryIntOp::Or => lhs | rhs,
222 BinaryIntOp::Xor | BinaryIntOp::Add | BinaryIntOp::Sub => lhs ^ rhs,
223 BinaryIntOp::Div => {
224 if !rhs {
225 return Err(BrilligArithmeticError::DivisionByZero);
226 } else {
227 lhs
228 }
229 }
230 _ => unreachable!("Operator not handled by this function: {op:?}"),
231 };
232 Ok(result)
233}
234
235fn evaluate_binary_int_op_cmp<T: Ord + PartialEq>(op: &BinaryIntOp, lhs: T, rhs: T) -> bool {
241 match op {
242 BinaryIntOp::Equals => lhs == rhs,
243 BinaryIntOp::LessThan => lhs < rhs,
244 BinaryIntOp::LessThanEquals => lhs <= rhs,
245 _ => unreachable!("Operator not handled by this function: {op:?}"),
246 }
247}
248
249fn evaluate_binary_int_op_shifts<T: ToPrimitive + Zero + Shl<Output = T> + Shr<Output = T>>(
258 op: &BinaryIntOp,
259 lhs: T,
260 rhs: T,
261) -> Result<T, BrilligArithmeticError> {
262 let bit_size = (size_of::<T>() * 8) as u128;
263 let rhs_val = rhs.to_u128().unwrap();
264 if rhs_val >= bit_size {
265 return Ok(T::zero());
266 }
267 match op {
268 BinaryIntOp::Shl => Ok(lhs << rhs),
269 BinaryIntOp::Shr => Ok(lhs >> rhs),
270 _ => unreachable!("Operator not handled by this function: {op:?}"),
271 }
272}
273
274fn evaluate_binary_int_op_arith<
284 T: WrappingAdd
285 + WrappingSub
286 + WrappingMul
287 + CheckedDiv
288 + BitAnd<Output = T>
289 + BitOr<Output = T>
290 + BitXor<Output = T>,
291>(
292 op: &BinaryIntOp,
293 lhs: T,
294 rhs: T,
295) -> Result<T, BrilligArithmeticError> {
296 let result = match op {
297 BinaryIntOp::Add => lhs.wrapping_add(&rhs),
298 BinaryIntOp::Sub => lhs.wrapping_sub(&rhs),
299 BinaryIntOp::Mul => lhs.wrapping_mul(&rhs),
300 BinaryIntOp::Div => lhs.checked_div(&rhs).ok_or(BrilligArithmeticError::DivisionByZero)?,
301 BinaryIntOp::And => lhs & rhs,
302 BinaryIntOp::Or => lhs | rhs,
303 BinaryIntOp::Xor => lhs ^ rhs,
304 _ => unreachable!("Operator not handled by this function: {op:?}"),
305 };
306 Ok(result)
307}
308
/// Tests for `evaluate_binary_int_op` over the fixed-width integer types.
#[cfg(test)]
mod int_ops {
    use super::*;
    use acir::{AcirField, FieldElement};

    /// A single test vector: operands `a`, `b` and the expected `result`, as raw `u128`s.
    struct TestParams {
        a: u128,
        b: u128,
        result: u128,
    }

    /// Runs `op` on `a` and `b` (both encoded as integers of `bit_size`), unwrapping the
    /// result, and returns it converted to `u128` through its field representation.
    fn evaluate_u128(op: &BinaryIntOp, a: u128, b: u128, bit_size: IntegerBitSize) -> u128 {
        let result_value: MemoryValue<FieldElement> = evaluate_binary_int_op(
            op,
            MemoryValue::new_integer(a, bit_size),
            MemoryValue::new_integer(b, bit_size),
            bit_size,
        )
        .unwrap();
        result_value.to_field().to_u128()
    }

    /// Returns the two's-complement encoding of `-a` at width `bit_size`, i.e. `2^bits - a`.
    /// Panics if `a` is zero.
    fn to_negative(a: u128, bit_size: IntegerBitSize) -> u128 {
        assert!(a > 0);
        if bit_size == IntegerBitSize::U128 {
            // 2^128 does not fit in a u128, so wrap from zero instead.
            0_u128.wrapping_sub(a)
        } else {
            let two_pow = 2_u128.pow(bit_size.into());
            two_pow - a
        }
    }

    /// Asserts that every test vector in `test_params` produces its expected result for `op`.
    fn evaluate_int_ops(test_params: Vec<TestParams>, op: BinaryIntOp, bit_size: IntegerBitSize) {
        for test in test_params {
            assert_eq!(evaluate_u128(&op, test.a, test.b, bit_size), test.result);
        }
    }

    #[test]
    fn add_test() {
        let bit_size = IntegerBitSize::U8;

        let test_ops = vec![
            TestParams { a: 50, b: 100, result: 150 },
            // 260 wraps modulo 2^8.
            TestParams { a: 250, b: 10, result: 4 },
            TestParams { a: 5, b: to_negative(3, bit_size), result: 2 },
            TestParams { a: to_negative(3, bit_size), b: 1, result: to_negative(2, bit_size) },
            TestParams { a: 5, b: to_negative(6, bit_size), result: to_negative(1, bit_size) },
        ];
        evaluate_int_ops(test_ops, BinaryIntOp::Add, bit_size);

        let bit_size = IntegerBitSize::U128;
        let test_ops = vec![
            TestParams { a: 5, b: to_negative(3, bit_size), result: 2 },
            TestParams { a: to_negative(3, bit_size), b: 1, result: to_negative(2, bit_size) },
        ];

        evaluate_int_ops(test_ops, BinaryIntOp::Add, bit_size);

        // Operands whose width differs from the op's bit size are rejected.
        assert_eq!(
            evaluate_binary_int_op(
                &BinaryIntOp::Add,
                MemoryValue::<FieldElement>::U8(1),
                MemoryValue::<FieldElement>::U16(2),
                IntegerBitSize::U8
            ),
            Err(BrilligArithmeticError::MismatchedRhsBitSize { rhs_bit_size: 16, op_bit_size: 8 })
        );
        assert_eq!(
            evaluate_binary_int_op(
                &BinaryIntOp::Add,
                MemoryValue::<FieldElement>::U16(2),
                MemoryValue::<FieldElement>::U8(1),
                IntegerBitSize::U8
            ),
            Err(BrilligArithmeticError::MismatchedLhsBitSize { lhs_bit_size: 16, op_bit_size: 8 })
        );
    }

    #[test]
    fn sub_test() {
        let bit_size = IntegerBitSize::U8;

        let test_ops = vec![
            TestParams { a: 50, b: 30, result: 20 },
            TestParams { a: 5, b: 10, result: to_negative(5, bit_size) },
            TestParams { a: 5, b: to_negative(3, bit_size), result: 8 },
            TestParams { a: to_negative(3, bit_size), b: 2, result: to_negative(5, bit_size) },
            TestParams { a: 254, b: to_negative(3, bit_size), result: 1 },
        ];
        evaluate_int_ops(test_ops, BinaryIntOp::Sub, bit_size);

        let bit_size = IntegerBitSize::U128;

        let test_ops = vec![
            TestParams { a: 5, b: 10, result: to_negative(5, bit_size) },
            TestParams { a: to_negative(3, bit_size), b: 2, result: to_negative(5, bit_size) },
        ];
        evaluate_int_ops(test_ops, BinaryIntOp::Sub, bit_size);

        // Operands whose width differs from the op's bit size are rejected.
        assert_eq!(
            evaluate_binary_int_op(
                &BinaryIntOp::Sub,
                MemoryValue::<FieldElement>::U8(1),
                MemoryValue::<FieldElement>::U16(1),
                IntegerBitSize::U8
            ),
            Err(BrilligArithmeticError::MismatchedRhsBitSize { rhs_bit_size: 16, op_bit_size: 8 })
        );
        assert_eq!(
            evaluate_binary_int_op(
                &BinaryIntOp::Sub,
                MemoryValue::<FieldElement>::U16(1),
                MemoryValue::<FieldElement>::U8(1),
                IntegerBitSize::U8
            ),
            Err(BrilligArithmeticError::MismatchedLhsBitSize { lhs_bit_size: 16, op_bit_size: 8 })
        );
    }

    #[test]
    fn mul_test() {
        let bit_size = IntegerBitSize::U8;

        let test_ops = vec![
            TestParams { a: 5, b: 3, result: 15 },
            // 500 wraps modulo 2^8.
            TestParams { a: 5, b: 100, result: 244 },
            TestParams { a: to_negative(1, bit_size), b: to_negative(5, bit_size), result: 5 },
            TestParams { a: to_negative(1, bit_size), b: 5, result: to_negative(5, bit_size) },
            TestParams { a: to_negative(2, bit_size), b: 7, result: to_negative(14, bit_size) },
        ];

        evaluate_int_ops(test_ops, BinaryIntOp::Mul, bit_size);

        // (2^64 - 1) * 3 wraps to 2^64 - 3, i.e. a - 2.
        let bit_size = IntegerBitSize::U64;
        let a = 2_u128.pow(bit_size.into()) - 1;
        let b = 3;

        assert_eq!(evaluate_u128(&BinaryIntOp::Mul, a, b, bit_size), a - 2);

        let bit_size = IntegerBitSize::U128;

        let test_ops = vec![
            TestParams { a: to_negative(1, bit_size), b: to_negative(5, bit_size), result: 5 },
            TestParams { a: to_negative(1, bit_size), b: 5, result: to_negative(5, bit_size) },
            TestParams { a: to_negative(2, bit_size), b: 7, result: to_negative(14, bit_size) },
        ];

        evaluate_int_ops(test_ops, BinaryIntOp::Mul, bit_size);

        // Operands whose width differs from the op's bit size are rejected.
        assert_eq!(
            evaluate_binary_int_op(
                &BinaryIntOp::Mul,
                MemoryValue::<FieldElement>::U8(1),
                MemoryValue::<FieldElement>::U16(1),
                IntegerBitSize::U8
            ),
            Err(BrilligArithmeticError::MismatchedRhsBitSize { rhs_bit_size: 16, op_bit_size: 8 })
        );
        assert_eq!(
            evaluate_binary_int_op(
                &BinaryIntOp::Mul,
                MemoryValue::<FieldElement>::U16(1),
                MemoryValue::<FieldElement>::U8(1),
                IntegerBitSize::U8
            ),
            Err(BrilligArithmeticError::MismatchedLhsBitSize { lhs_bit_size: 16, op_bit_size: 8 })
        );
    }

    #[test]
    fn div_test() {
        let bit_size = IntegerBitSize::U8;

        // Integer division truncates.
        let test_ops =
            vec![TestParams { a: 5, b: 3, result: 1 }, TestParams { a: 5, b: 10, result: 0 }];

        evaluate_int_ops(test_ops, BinaryIntOp::Div, bit_size);

        // Operands whose width differs from the op's bit size are rejected.
        assert_eq!(
            evaluate_binary_int_op(
                &BinaryIntOp::Div,
                MemoryValue::<FieldElement>::U8(1),
                MemoryValue::<FieldElement>::U16(1),
                IntegerBitSize::U8
            ),
            Err(BrilligArithmeticError::MismatchedRhsBitSize { rhs_bit_size: 16, op_bit_size: 8 })
        );
        assert_eq!(
            evaluate_binary_int_op(
                &BinaryIntOp::Div,
                MemoryValue::<FieldElement>::U16(1),
                MemoryValue::<FieldElement>::U8(1),
                IntegerBitSize::U8
            ),
            Err(BrilligArithmeticError::MismatchedLhsBitSize { lhs_bit_size: 16, op_bit_size: 8 })
        );

        // A zero divisor is reported as an error rather than panicking.
        assert_eq!(
            evaluate_binary_int_op(
                &BinaryIntOp::Div,
                MemoryValue::<FieldElement>::U8(1),
                MemoryValue::<FieldElement>::U8(0),
                IntegerBitSize::U8
            ),
            Err(BrilligArithmeticError::DivisionByZero)
        );
    }

    #[test]
    fn shl_test() {
        let bit_size = IntegerBitSize::U8;

        // Bits shifted past the top of the byte are discarded.
        let test_ops =
            vec![TestParams { a: 1, b: 7, result: 128 }, TestParams { a: 5, b: 7, result: 128 }];

        evaluate_int_ops(test_ops, BinaryIntOp::Shl, bit_size);
        // Shifting by the full width yields zero.
        assert_eq!(
            evaluate_binary_int_op(
                &BinaryIntOp::Shl,
                MemoryValue::<FieldElement>::U8(1u8),
                MemoryValue::<FieldElement>::U8(8u8),
                IntegerBitSize::U8
            ),
            Ok(MemoryValue::<FieldElement>::U8(0u8))
        );
    }

    #[test]
    fn shr_test() {
        let bit_size = IntegerBitSize::U8;

        let test_ops =
            vec![TestParams { a: 1, b: 0, result: 1 }, TestParams { a: 5, b: 1, result: 2 }];

        evaluate_int_ops(test_ops, BinaryIntOp::Shr, bit_size);
        // Shifting by the full width yields zero.
        assert_eq!(
            evaluate_binary_int_op(
                &BinaryIntOp::Shr,
                MemoryValue::<FieldElement>::U8(1u8),
                MemoryValue::<FieldElement>::U8(8u8),
                IntegerBitSize::U8
            ),
            Ok(MemoryValue::<FieldElement>::U8(0u8))
        );
    }

    /// Reference model for `Shl`/`Shr` at `bit_width` bits: masks both operands to the width,
    /// returns 0 when the shift amount is at least the width, otherwise shifts and masks the result.
    fn reference_shift(op: &BinaryIntOp, lhs: u128, rhs: u128, bit_width: u32) -> u128 {
        let mask = if bit_width == 128 { u128::MAX } else { (1u128 << bit_width) - 1 };
        let lhs = lhs & mask;
        let rhs = rhs & mask;
        if rhs >= u128::from(bit_width) {
            return 0;
        }
        let result = match op {
            BinaryIntOp::Shl => lhs << rhs as u32,
            BinaryIntOp::Shr => lhs >> rhs as u32,
            _ => unreachable!(),
        };
        result & mask
    }

    // Property tests comparing both shift operators against `reference_shift` for each width,
    // with operands drawn from the full range of that width.
    proptest::proptest! {
        #[test]
        fn shift_u8_fuzz(lhs in 0u128..=u128::from(u8::MAX), rhs in 0u128..=u128::from(u8::MAX)) {
            let bit_size = IntegerBitSize::U8;
            for op in [BinaryIntOp::Shl, BinaryIntOp::Shr] {
                let actual = evaluate_u128(&op, lhs, rhs, bit_size);
                let expected = reference_shift(&op, lhs, rhs, 8);
                proptest::prop_assert_eq!(actual, expected);
            }
        }

        #[test]
        fn shift_u16_fuzz(lhs in 0u128..=u128::from(u16::MAX), rhs in 0u128..=u128::from(u16::MAX)) {
            let bit_size = IntegerBitSize::U16;
            for op in [BinaryIntOp::Shl, BinaryIntOp::Shr] {
                let actual = evaluate_u128(&op, lhs, rhs, bit_size);
                let expected = reference_shift(&op, lhs, rhs, 16);
                proptest::prop_assert_eq!(actual, expected);
            }
        }

        #[test]
        fn shift_u32_fuzz(lhs in 0u128..=u128::from(u32::MAX), rhs in 0u128..=u128::from(u32::MAX)) {
            let bit_size = IntegerBitSize::U32;
            for op in [BinaryIntOp::Shl, BinaryIntOp::Shr] {
                let actual = evaluate_u128(&op, lhs, rhs, bit_size);
                let expected = reference_shift(&op, lhs, rhs, 32);
                proptest::prop_assert_eq!(actual, expected);
            }
        }

        #[test]
        fn shift_u64_fuzz(lhs in 0u128..=u128::from(u64::MAX), rhs in 0u128..=u128::from(u64::MAX)) {
            let bit_size = IntegerBitSize::U64;
            for op in [BinaryIntOp::Shl, BinaryIntOp::Shr] {
                let actual = evaluate_u128(&op, lhs, rhs, bit_size);
                let expected = reference_shift(&op, lhs, rhs, 64);
                proptest::prop_assert_eq!(actual, expected);
            }
        }

        #[test]
        fn shift_u128_fuzz(lhs: u128, rhs: u128) {
            let bit_size = IntegerBitSize::U128;
            for op in [BinaryIntOp::Shl, BinaryIntOp::Shr] {
                let actual = evaluate_u128(&op, lhs, rhs, bit_size);
                let expected = reference_shift(&op, lhs, rhs, 128);
                proptest::prop_assert_eq!(actual, expected);
            }
        }
    }

    #[test]
    fn comparison_ops_test() {
        let bit_size = IntegerBitSize::U8;

        // Comparison results are U1 values, read back here as 0 or 1.
        let test_ops = vec![
            TestParams { a: 5, b: 5, result: 1 },
            TestParams { a: 10, b: 5, result: 0 },
            TestParams { a: 0, b: 0, result: 1 },
        ];
        evaluate_int_ops(test_ops, BinaryIntOp::Equals, bit_size);

        let test_ops = vec![
            TestParams { a: 4, b: 5, result: 1 },
            TestParams { a: 5, b: 4, result: 0 },
            TestParams { a: 5, b: 5, result: 0 },
        ];
        evaluate_int_ops(test_ops, BinaryIntOp::LessThan, bit_size);

        let test_ops = vec![
            TestParams { a: 4, b: 5, result: 1 },
            TestParams { a: 5, b: 4, result: 0 },
            TestParams { a: 5, b: 5, result: 1 },
        ];
        evaluate_int_ops(test_ops, BinaryIntOp::LessThanEquals, bit_size);

        // Operands whose width differs from the op's bit size are rejected.
        assert_eq!(
            evaluate_binary_int_op(
                &BinaryIntOp::Equals,
                MemoryValue::<FieldElement>::U8(1),
                MemoryValue::<FieldElement>::U16(1),
                IntegerBitSize::U8
            ),
            Err(BrilligArithmeticError::MismatchedRhsBitSize { rhs_bit_size: 16, op_bit_size: 8 })
        );
        assert_eq!(
            evaluate_binary_int_op(
                &BinaryIntOp::LessThan,
                MemoryValue::<FieldElement>::U16(1),
                MemoryValue::<FieldElement>::U8(1),
                IntegerBitSize::U8
            ),
            Err(BrilligArithmeticError::MismatchedLhsBitSize { lhs_bit_size: 16, op_bit_size: 8 })
        );
        assert_eq!(
            evaluate_binary_int_op(
                &BinaryIntOp::LessThanEquals,
                MemoryValue::<FieldElement>::U16(1),
                MemoryValue::<FieldElement>::U8(1),
                IntegerBitSize::U8
            ),
            Err(BrilligArithmeticError::MismatchedLhsBitSize { lhs_bit_size: 16, op_bit_size: 8 })
        );
    }
}
692
/// Tests for `evaluate_binary_field_op`.
#[cfg(test)]
mod field_ops {
    use super::*;
    use acir::{AcirField, FieldElement};

    /// A single test vector: field operands `a`, `b` and the expected `result`.
    struct TestParams {
        a: FieldElement,
        b: FieldElement,
        result: FieldElement,
    }

    /// Runs `op` on two field-typed operands, unwrapping the result, and returns it as a field element.
    fn evaluate_field_u128(op: &BinaryFieldOp, a: FieldElement, b: FieldElement) -> FieldElement {
        let result_value: MemoryValue<FieldElement> =
            evaluate_binary_field_op(op, MemoryValue::new_field(a), MemoryValue::new_field(b))
                .unwrap();
        result_value.to_field()
    }

    /// Asserts that every test vector in `test_params` produces its expected result for `op`.
    fn evaluate_field_ops(test_params: Vec<TestParams>, op: BinaryFieldOp) {
        for test in test_params {
            assert_eq!(evaluate_field_u128(&op, test.a, test.b), test.result);
        }
    }

    #[test]
    fn add_test() {
        let test_ops = vec![
            TestParams { a: 1u32.into(), b: 2u32.into(), result: 3u32.into() },
            TestParams { a: 5u32.into(), b: 10u32.into(), result: 15u32.into() },
            TestParams { a: 250u32.into(), b: 10u32.into(), result: 260u32.into() },
        ];
        evaluate_field_ops(test_ops, BinaryFieldOp::Add);

        // A non-field operand is rejected; 254 is the bit size expected for field values here.
        assert_eq!(
            evaluate_binary_field_op(
                &BinaryFieldOp::Add,
                MemoryValue::new_field(FieldElement::from(1u32)),
                MemoryValue::<FieldElement>::U16(2),
            ),
            Err(BrilligArithmeticError::MismatchedRhsBitSize {
                rhs_bit_size: 16,
                op_bit_size: 254
            })
        );

        assert_eq!(
            evaluate_binary_field_op(
                &BinaryFieldOp::Add,
                MemoryValue::<FieldElement>::U16(1),
                MemoryValue::new_field(FieldElement::from(2u32)),
            ),
            Err(BrilligArithmeticError::MismatchedLhsBitSize {
                lhs_bit_size: 16,
                op_bit_size: 254
            })
        );
    }

    #[test]
    fn sub_test() {
        let test_ops = vec![
            TestParams { a: 5u32.into(), b: 3u32.into(), result: 2u32.into() },
            TestParams { a: 2u32.into(), b: 10u32.into(), result: FieldElement::from(-8_i128) },
        ];
        evaluate_field_ops(test_ops, BinaryFieldOp::Sub);

        // A non-field operand is rejected.
        assert_eq!(
            evaluate_binary_field_op(
                &BinaryFieldOp::Sub,
                MemoryValue::new_field(FieldElement::from(1u32)),
                MemoryValue::<FieldElement>::U16(1),
            ),
            Err(BrilligArithmeticError::MismatchedRhsBitSize {
                rhs_bit_size: 16,
                op_bit_size: 254
            })
        );

        assert_eq!(
            evaluate_binary_field_op(
                &BinaryFieldOp::Sub,
                MemoryValue::<FieldElement>::U16(1),
                MemoryValue::new_field(FieldElement::from(1u32)),
            ),
            Err(BrilligArithmeticError::MismatchedLhsBitSize {
                lhs_bit_size: 16,
                op_bit_size: 254
            })
        );
    }

    #[test]
    fn mul_test() {
        let test_ops = vec![
            TestParams { a: 2u32.into(), b: 3u32.into(), result: 6u32.into() },
            TestParams { a: 10u32.into(), b: 25u32.into(), result: 250u32.into() },
        ];
        evaluate_field_ops(test_ops, BinaryFieldOp::Mul);

        // A non-field operand is rejected.
        assert_eq!(
            evaluate_binary_field_op(
                &BinaryFieldOp::Mul,
                MemoryValue::new_field(FieldElement::from(1u32)),
                MemoryValue::<FieldElement>::U16(1),
            ),
            Err(BrilligArithmeticError::MismatchedRhsBitSize {
                rhs_bit_size: 16,
                op_bit_size: 254
            })
        );

        assert_eq!(
            evaluate_binary_field_op(
                &BinaryFieldOp::Mul,
                MemoryValue::<FieldElement>::U16(1),
                MemoryValue::new_field(FieldElement::from(1u32)),
            ),
            Err(BrilligArithmeticError::MismatchedLhsBitSize {
                lhs_bit_size: 16,
                op_bit_size: 254
            })
        );
    }

    #[test]
    fn div_test() {
        let test_ops = vec![
            TestParams { a: 10u32.into(), b: 2u32.into(), result: 5u32.into() },
            TestParams { a: 9u32.into(), b: 3u32.into(), result: 3u32.into() },
            TestParams {
                a: 10u32.into(),
                b: FieldElement::from(-1_i128),
                result: FieldElement::from(-10_i128),
            },
            TestParams { a: 10u32.into(), b: 1u32.into(), result: 10u32.into() },
            TestParams {
                // Field division: 10 / 20 is the multiplicative inverse of 2, not 0.
                a: 10u32.into(),
                b: 20u32.into(),
                result: FieldElement::try_from_str(
                    "10944121435919637611123202872628637544274182200208017171849102093287904247809",
                )
                .unwrap(),
            },
            TestParams {
                // Field division: 100 multiplied by the inverse of 7.
                a: 100u32.into(),
                b: 7u32.into(),
                result: FieldElement::try_from_str(
                    "6253783677668364349213258784359221453870961257261724098199486910450230998762",
                )
                .unwrap(),
            },
        ];
        evaluate_field_ops(test_ops, BinaryFieldOp::Div);

        // A zero divisor is reported as an error.
        assert_eq!(
            evaluate_binary_field_op(
                &BinaryFieldOp::Div,
                MemoryValue::new_field(FieldElement::from(1u128)),
                MemoryValue::new_field(FieldElement::zero()),
            ),
            Err(BrilligArithmeticError::DivisionByZero)
        );

        // A non-field operand is rejected.
        assert_eq!(
            evaluate_binary_field_op(
                &BinaryFieldOp::Div,
                MemoryValue::new_field(FieldElement::from(1u32)),
                MemoryValue::<FieldElement>::U16(1),
            ),
            Err(BrilligArithmeticError::MismatchedRhsBitSize {
                rhs_bit_size: 16,
                op_bit_size: 254
            })
        );

        assert_eq!(
            evaluate_binary_field_op(
                &BinaryFieldOp::Div,
                MemoryValue::<FieldElement>::U16(1),
                MemoryValue::new_field(FieldElement::from(1u32)),
            ),
            Err(BrilligArithmeticError::MismatchedLhsBitSize {
                lhs_bit_size: 16,
                op_bit_size: 254
            })
        );
    }

    #[test]
    fn integer_div_test() {
        let test_ops = vec![
            TestParams { a: 10u32.into(), b: 2u32.into(), result: 5u32.into() },
            TestParams { a: 9u32.into(), b: 3u32.into(), result: 3u32.into() },
            // -1 is the largest field element when read as an integer, so the quotient is 0.
            TestParams { a: 10u32.into(), b: FieldElement::from(-1_i128), result: 0u32.into() },
            TestParams { a: 10u32.into(), b: 1u32.into(), result: 10u32.into() },
            // Integer division truncates, unlike field division.
            TestParams { a: 10u32.into(), b: 20u32.into(), result: 0u32.into() },
            TestParams { a: 100u32.into(), b: 7u32.into(), result: 14u32.into() },
        ];
        evaluate_field_ops(test_ops, BinaryFieldOp::IntegerDiv);

        // A zero divisor is reported as an error.
        assert_eq!(
            evaluate_binary_field_op(
                &BinaryFieldOp::IntegerDiv,
                MemoryValue::new_field(FieldElement::from(1u128)),
                MemoryValue::new_field(FieldElement::zero()),
            ),
            Err(BrilligArithmeticError::DivisionByZero)
        );

        // A non-field operand is rejected.
        assert_eq!(
            evaluate_binary_field_op(
                &BinaryFieldOp::IntegerDiv,
                MemoryValue::new_field(FieldElement::from(1u32)),
                MemoryValue::<FieldElement>::U16(1),
            ),
            Err(BrilligArithmeticError::MismatchedRhsBitSize {
                rhs_bit_size: 16,
                op_bit_size: 254
            })
        );

        assert_eq!(
            evaluate_binary_field_op(
                &BinaryFieldOp::IntegerDiv,
                MemoryValue::<FieldElement>::U16(1),
                MemoryValue::new_field(FieldElement::from(1u32)),
            ),
            Err(BrilligArithmeticError::MismatchedLhsBitSize {
                lhs_bit_size: 16,
                op_bit_size: 254
            })
        );
    }

    #[test]
    fn comparison_ops_test() {
        // Comparison results are read back as the field elements 0 or 1.
        let test_ops = vec![
            TestParams { a: 5u32.into(), b: 5u32.into(), result: 1u32.into() },
            TestParams { a: 10u32.into(), b: 5u32.into(), result: 0u32.into() },
            TestParams { a: 0u32.into(), b: 0u32.into(), result: 1u32.into() },
        ];
        evaluate_field_ops(test_ops, BinaryFieldOp::Equals);

        let test_ops = vec![
            TestParams { a: 4u32.into(), b: 5u32.into(), result: 1u32.into() },
            TestParams { a: 5u32.into(), b: 4u32.into(), result: 0u32.into() },
            TestParams { a: 5u32.into(), b: 5u32.into(), result: 0u32.into() },
        ];
        evaluate_field_ops(test_ops, BinaryFieldOp::LessThan);

        let test_ops = vec![
            TestParams { a: 4u32.into(), b: 5u32.into(), result: 1u32.into() },
            TestParams { a: 5u32.into(), b: 4u32.into(), result: 0u32.into() },
            TestParams { a: 5u32.into(), b: 5u32.into(), result: 1u32.into() },
        ];
        evaluate_field_ops(test_ops, BinaryFieldOp::LessThanEquals);

        // A non-field operand is rejected.
        assert_eq!(
            evaluate_binary_field_op(
                &BinaryFieldOp::Equals,
                MemoryValue::new_field(1u32.into()),
                MemoryValue::<FieldElement>::U16(1),
            ),
            Err(BrilligArithmeticError::MismatchedRhsBitSize {
                rhs_bit_size: 16,
                op_bit_size: 254
            })
        );

        assert_eq!(
            evaluate_binary_field_op(
                &BinaryFieldOp::LessThan,
                MemoryValue::<FieldElement>::U16(1),
                MemoryValue::new_field(1u32.into()),
            ),
            Err(BrilligArithmeticError::MismatchedLhsBitSize {
                lhs_bit_size: 16,
                op_bit_size: 254
            })
        );

        assert_eq!(
            evaluate_binary_field_op(
                &BinaryFieldOp::LessThanEquals,
                MemoryValue::<FieldElement>::U16(1),
                MemoryValue::new_field(1u32.into()),
            ),
            Err(BrilligArithmeticError::MismatchedLhsBitSize {
                lhs_bit_size: 16,
                op_bit_size: 254
            })
        );
    }
}