1use std::ops::{BitAnd, BitOr, BitXor, Shl, Shr};
4
5use acir::AcirField;
6use acir::brillig::{BinaryFieldOp, BinaryIntOp, BitSize, IntegerBitSize};
7use num_bigint::BigUint;
8use num_traits::{CheckedDiv, ToPrimitive, WrappingAdd, WrappingMul, WrappingSub, Zero};
9
10use crate::memory::{MemoryTypeError, MemoryValue};
11
12#[derive(Debug, PartialEq, thiserror::Error)]
13pub(crate) enum BrilligArithmeticError {
14 #[error("Bit size for lhs {lhs_bit_size} does not match op bit size {op_bit_size}")]
15 MismatchedLhsBitSize { lhs_bit_size: u32, op_bit_size: u32 },
16 #[error("Bit size for rhs {rhs_bit_size} does not match op bit size {op_bit_size}")]
17 MismatchedRhsBitSize { rhs_bit_size: u32, op_bit_size: u32 },
18 #[error("Attempted to divide by zero")]
19 DivisionByZero,
20}
21
22pub(crate) fn evaluate_binary_field_op<F: AcirField>(
24 op: &BinaryFieldOp,
25 lhs: MemoryValue<F>,
26 rhs: MemoryValue<F>,
27) -> Result<MemoryValue<F>, BrilligArithmeticError> {
28 let expect_field = |value: MemoryValue<F>, make_err: fn(u32, u32) -> BrilligArithmeticError| {
29 value.expect_field().map_err(|err| {
30 if let MemoryTypeError::MismatchedBitSize { value_bit_size, expected_bit_size } = err {
31 make_err(value_bit_size, expected_bit_size)
32 } else {
33 unreachable!("MemoryTypeError NotInteger is only produced by to_u128")
34 }
35 })
36 };
37 let a = expect_field(lhs, |vbs, ebs| BrilligArithmeticError::MismatchedLhsBitSize {
38 lhs_bit_size: vbs,
39 op_bit_size: ebs,
40 })?;
41 let b = expect_field(rhs, |vbs, ebs| BrilligArithmeticError::MismatchedRhsBitSize {
42 rhs_bit_size: vbs,
43 op_bit_size: ebs,
44 })?;
45
46 Ok(match op {
47 BinaryFieldOp::Add => MemoryValue::new_field(a + b),
49 BinaryFieldOp::Sub => MemoryValue::new_field(a - b),
50 BinaryFieldOp::Mul => MemoryValue::new_field(a * b),
51 BinaryFieldOp::Div => {
52 if b.is_zero() {
53 return Err(BrilligArithmeticError::DivisionByZero);
54 } else if b.is_one() {
55 MemoryValue::new_field(a)
56 } else if b == -F::one() {
57 MemoryValue::new_field(-a)
58 } else {
59 MemoryValue::new_field(a / b)
60 }
61 }
62 BinaryFieldOp::IntegerDiv => {
63 if b.is_zero() {
72 return Err(BrilligArithmeticError::DivisionByZero);
73 } else {
74 let a_big = BigUint::from_bytes_be(&a.to_be_bytes());
75 let b_big = BigUint::from_bytes_be(&b.to_be_bytes());
76
77 let result = a_big / b_big;
78 MemoryValue::new_field(F::from_be_bytes_reduce(&result.to_bytes_be()))
79 }
80 }
81 BinaryFieldOp::Equals => (a == b).into(),
82 BinaryFieldOp::LessThan => (a < b).into(),
83 BinaryFieldOp::LessThanEquals => (a <= b).into(),
84 })
85}
86
87pub(crate) fn evaluate_binary_int_op<F: AcirField>(
89 op: &BinaryIntOp,
90 lhs: MemoryValue<F>,
91 rhs: MemoryValue<F>,
92 bit_size: IntegerBitSize,
93) -> Result<MemoryValue<F>, BrilligArithmeticError> {
94 match op {
95 BinaryIntOp::Add
96 | BinaryIntOp::Sub
97 | BinaryIntOp::Mul
98 | BinaryIntOp::Div
99 | BinaryIntOp::And
100 | BinaryIntOp::Or
101 | BinaryIntOp::Xor => match (lhs, rhs, bit_size) {
102 (MemoryValue::U1(lhs), MemoryValue::U1(rhs), IntegerBitSize::U1) => {
103 evaluate_binary_int_op_u1(op, lhs, rhs).map(MemoryValue::U1)
104 }
105 (MemoryValue::U8(lhs), MemoryValue::U8(rhs), IntegerBitSize::U8) => {
106 evaluate_binary_int_op_arith(op, lhs, rhs).map(MemoryValue::U8)
107 }
108 (MemoryValue::U16(lhs), MemoryValue::U16(rhs), IntegerBitSize::U16) => {
109 evaluate_binary_int_op_arith(op, lhs, rhs).map(MemoryValue::U16)
110 }
111 (MemoryValue::U32(lhs), MemoryValue::U32(rhs), IntegerBitSize::U32) => {
112 evaluate_binary_int_op_arith(op, lhs, rhs).map(MemoryValue::U32)
113 }
114 (MemoryValue::U64(lhs), MemoryValue::U64(rhs), IntegerBitSize::U64) => {
115 evaluate_binary_int_op_arith(op, lhs, rhs).map(MemoryValue::U64)
116 }
117 (MemoryValue::U128(lhs), MemoryValue::U128(rhs), IntegerBitSize::U128) => {
118 evaluate_binary_int_op_arith(op, lhs, rhs).map(MemoryValue::U128)
119 }
120 (lhs, _, _) if lhs.bit_size() != BitSize::Integer(bit_size) => {
121 Err(BrilligArithmeticError::MismatchedLhsBitSize {
122 lhs_bit_size: lhs.bit_size().to_u32::<F>(),
123 op_bit_size: bit_size.into(),
124 })
125 }
126 (_, rhs, _) if rhs.bit_size() != BitSize::Integer(bit_size) => {
127 Err(BrilligArithmeticError::MismatchedRhsBitSize {
128 rhs_bit_size: rhs.bit_size().to_u32::<F>(),
129 op_bit_size: bit_size.into(),
130 })
131 }
132 _ => unreachable!("Invalid arguments are covered by the two arms above."),
133 },
134
135 BinaryIntOp::Equals | BinaryIntOp::LessThan | BinaryIntOp::LessThanEquals => {
136 match (lhs, rhs, bit_size) {
137 (MemoryValue::U1(lhs), MemoryValue::U1(rhs), IntegerBitSize::U1) => {
138 Ok(MemoryValue::U1(evaluate_binary_int_op_cmp(op, lhs, rhs)))
139 }
140 (MemoryValue::U8(lhs), MemoryValue::U8(rhs), IntegerBitSize::U8) => {
141 Ok(MemoryValue::U1(evaluate_binary_int_op_cmp(op, lhs, rhs)))
142 }
143 (MemoryValue::U16(lhs), MemoryValue::U16(rhs), IntegerBitSize::U16) => {
144 Ok(MemoryValue::U1(evaluate_binary_int_op_cmp(op, lhs, rhs)))
145 }
146 (MemoryValue::U32(lhs), MemoryValue::U32(rhs), IntegerBitSize::U32) => {
147 Ok(MemoryValue::U1(evaluate_binary_int_op_cmp(op, lhs, rhs)))
148 }
149 (MemoryValue::U64(lhs), MemoryValue::U64(rhs), IntegerBitSize::U64) => {
150 Ok(MemoryValue::U1(evaluate_binary_int_op_cmp(op, lhs, rhs)))
151 }
152 (MemoryValue::U128(lhs), MemoryValue::U128(rhs), IntegerBitSize::U128) => {
153 Ok(MemoryValue::U1(evaluate_binary_int_op_cmp(op, lhs, rhs)))
154 }
155 (lhs, _, _) if lhs.bit_size() != BitSize::Integer(bit_size) => {
156 Err(BrilligArithmeticError::MismatchedLhsBitSize {
157 lhs_bit_size: lhs.bit_size().to_u32::<F>(),
158 op_bit_size: bit_size.into(),
159 })
160 }
161 (_, rhs, _) if rhs.bit_size() != BitSize::Integer(bit_size) => {
162 Err(BrilligArithmeticError::MismatchedRhsBitSize {
163 rhs_bit_size: rhs.bit_size().to_u32::<F>(),
164 op_bit_size: bit_size.into(),
165 })
166 }
167 _ => unreachable!("Invalid arguments are covered by the two arms above."),
168 }
169 }
170
171 BinaryIntOp::Shl | BinaryIntOp::Shr => match (lhs, rhs, bit_size) {
172 (MemoryValue::U1(lhs), MemoryValue::U1(rhs), IntegerBitSize::U1) => {
173 if rhs {
174 Ok(MemoryValue::U1(false))
175 } else {
176 Ok(MemoryValue::U1(lhs))
177 }
178 }
179 (MemoryValue::U8(lhs), MemoryValue::U8(rhs), IntegerBitSize::U8) => {
180 Ok(MemoryValue::U8(evaluate_binary_int_op_shifts(op, lhs, rhs)?))
181 }
182 (MemoryValue::U16(lhs), MemoryValue::U16(rhs), IntegerBitSize::U16) => {
183 Ok(MemoryValue::U16(evaluate_binary_int_op_shifts(op, lhs, rhs)?))
184 }
185 (MemoryValue::U32(lhs), MemoryValue::U32(rhs), IntegerBitSize::U32) => {
186 Ok(MemoryValue::U32(evaluate_binary_int_op_shifts(op, lhs, rhs)?))
187 }
188 (MemoryValue::U64(lhs), MemoryValue::U64(rhs), IntegerBitSize::U64) => {
189 Ok(MemoryValue::U64(evaluate_binary_int_op_shifts(op, lhs, rhs)?))
190 }
191 (MemoryValue::U128(lhs), MemoryValue::U128(rhs), IntegerBitSize::U128) => {
192 Ok(MemoryValue::U128(evaluate_binary_int_op_shifts(op, lhs, rhs)?))
193 }
194 (lhs, _, _) if lhs.bit_size() != BitSize::Integer(bit_size) => {
195 Err(BrilligArithmeticError::MismatchedLhsBitSize {
196 lhs_bit_size: lhs.bit_size().to_u32::<F>(),
197 op_bit_size: bit_size.into(),
198 })
199 }
200 (_, rhs, _) if rhs.bit_size() != BitSize::Integer(bit_size) => {
201 Err(BrilligArithmeticError::MismatchedRhsBitSize {
202 rhs_bit_size: rhs.bit_size().to_u32::<F>(),
203 op_bit_size: bit_size.into(),
204 })
205 }
206 _ => unreachable!("Invalid arguments are covered by the two arms above."),
207 },
208 }
209}
210
211fn evaluate_binary_int_op_u1(
221 op: &BinaryIntOp,
222 lhs: bool,
223 rhs: bool,
224) -> Result<bool, BrilligArithmeticError> {
225 let result = match op {
226 BinaryIntOp::Equals => lhs == rhs,
227 BinaryIntOp::LessThan => !lhs & rhs,
228 BinaryIntOp::LessThanEquals => lhs <= rhs,
229 BinaryIntOp::And | BinaryIntOp::Mul => lhs & rhs,
230 BinaryIntOp::Or => lhs | rhs,
231 BinaryIntOp::Xor | BinaryIntOp::Add | BinaryIntOp::Sub => lhs ^ rhs,
232 BinaryIntOp::Div => {
233 if !rhs {
234 return Err(BrilligArithmeticError::DivisionByZero);
235 } else {
236 lhs
237 }
238 }
239 _ => unreachable!("Operator not handled by this function: {op:?}"),
240 };
241 Ok(result)
242}
243
244fn evaluate_binary_int_op_cmp<T: Ord + PartialEq>(op: &BinaryIntOp, lhs: T, rhs: T) -> bool {
250 match op {
251 BinaryIntOp::Equals => lhs == rhs,
252 BinaryIntOp::LessThan => lhs < rhs,
253 BinaryIntOp::LessThanEquals => lhs <= rhs,
254 _ => unreachable!("Operator not handled by this function: {op:?}"),
255 }
256}
257
258fn evaluate_binary_int_op_shifts<T: ToPrimitive + Zero + Shl<Output = T> + Shr<Output = T>>(
267 op: &BinaryIntOp,
268 lhs: T,
269 rhs: T,
270) -> Result<T, BrilligArithmeticError> {
271 let bit_size = (size_of::<T>() * 8) as u128;
272 let rhs_val = rhs.to_u128().unwrap();
273 if rhs_val >= bit_size {
274 return Ok(T::zero());
275 }
276 match op {
277 BinaryIntOp::Shl => Ok(lhs << rhs),
278 BinaryIntOp::Shr => Ok(lhs >> rhs),
279 _ => unreachable!("Operator not handled by this function: {op:?}"),
280 }
281}
282
283fn evaluate_binary_int_op_arith<
293 T: WrappingAdd
294 + WrappingSub
295 + WrappingMul
296 + CheckedDiv
297 + BitAnd<Output = T>
298 + BitOr<Output = T>
299 + BitXor<Output = T>,
300>(
301 op: &BinaryIntOp,
302 lhs: T,
303 rhs: T,
304) -> Result<T, BrilligArithmeticError> {
305 let result = match op {
306 BinaryIntOp::Add => lhs.wrapping_add(&rhs),
307 BinaryIntOp::Sub => lhs.wrapping_sub(&rhs),
308 BinaryIntOp::Mul => lhs.wrapping_mul(&rhs),
309 BinaryIntOp::Div => lhs.checked_div(&rhs).ok_or(BrilligArithmeticError::DivisionByZero)?,
310 BinaryIntOp::And => lhs & rhs,
311 BinaryIntOp::Or => lhs | rhs,
312 BinaryIntOp::Xor => lhs ^ rhs,
313 _ => unreachable!("Operator not handled by this function: {op:?}"),
314 };
315 Ok(result)
316}
317
#[cfg(test)]
mod int_ops {
    use super::*;
    use acir::{AcirField, FieldElement};

    /// One test vector: evaluating the op on `a` and `b` must yield `result`.
    struct TestParams {
        a: u128,
        b: u128,
        result: u128,
    }

    /// Evaluates `op` on two integers of width `bit_size` and returns the result as a `u128`.
    ///
    /// Panics if the evaluation returns an error.
    fn evaluate_u128(op: &BinaryIntOp, a: u128, b: u128, bit_size: IntegerBitSize) -> u128 {
        let result_value: MemoryValue<FieldElement> = evaluate_binary_int_op(
            op,
            MemoryValue::new_integer(a, bit_size),
            MemoryValue::new_integer(b, bit_size),
            bit_size,
        )
        .unwrap();
        result_value.to_field().to_u128()
    }

    /// Returns the two's-complement encoding of `-a` at `bit_size`; requires `a > 0`.
    fn to_negative(a: u128, bit_size: IntegerBitSize) -> u128 {
        assert!(a > 0);
        if bit_size == IntegerBitSize::U128 {
            0_u128.wrapping_sub(a)
        } else {
            let two_pow = 2_u128.pow(bit_size.into());
            two_pow - a
        }
    }

    /// Checks every vector in `test_params` against `op` evaluated at `bit_size`.
    fn evaluate_int_ops(test_params: Vec<TestParams>, op: BinaryIntOp, bit_size: IntegerBitSize) {
        for test in test_params {
            assert_eq!(evaluate_u128(&op, test.a, test.b, bit_size), test.result);
        }
    }

    /// Asserts that `op`, evaluated at `U8`, rejects a `U16` operand on either side and
    /// reports the offending operand.
    fn assert_rejects_mismatched_operands(op: BinaryIntOp) {
        assert_eq!(
            evaluate_binary_int_op(
                &op,
                MemoryValue::<FieldElement>::U8(1),
                MemoryValue::<FieldElement>::U16(1),
                IntegerBitSize::U8
            ),
            Err(BrilligArithmeticError::MismatchedRhsBitSize { rhs_bit_size: 16, op_bit_size: 8 })
        );
        assert_eq!(
            evaluate_binary_int_op(
                &op,
                MemoryValue::<FieldElement>::U16(1),
                MemoryValue::<FieldElement>::U8(1),
                IntegerBitSize::U8
            ),
            Err(BrilligArithmeticError::MismatchedLhsBitSize { lhs_bit_size: 16, op_bit_size: 8 })
        );
    }

    #[test]
    fn add_test() {
        let bit_size = IntegerBitSize::U8;

        let test_ops = vec![
            TestParams { a: 50, b: 100, result: 150 },
            TestParams { a: 250, b: 10, result: 4 },
            TestParams { a: 5, b: to_negative(3, bit_size), result: 2 },
            TestParams { a: to_negative(3, bit_size), b: 1, result: to_negative(2, bit_size) },
            TestParams { a: 5, b: to_negative(6, bit_size), result: to_negative(1, bit_size) },
        ];
        evaluate_int_ops(test_ops, BinaryIntOp::Add, bit_size);

        let bit_size = IntegerBitSize::U128;
        let test_ops = vec![
            TestParams { a: 5, b: to_negative(3, bit_size), result: 2 },
            TestParams { a: to_negative(3, bit_size), b: 1, result: to_negative(2, bit_size) },
        ];
        evaluate_int_ops(test_ops, BinaryIntOp::Add, bit_size);

        assert_rejects_mismatched_operands(BinaryIntOp::Add);
    }

    #[test]
    fn sub_test() {
        let bit_size = IntegerBitSize::U8;

        let test_ops = vec![
            TestParams { a: 50, b: 30, result: 20 },
            TestParams { a: 5, b: 10, result: to_negative(5, bit_size) },
            TestParams { a: 5, b: to_negative(3, bit_size), result: 8 },
            TestParams { a: to_negative(3, bit_size), b: 2, result: to_negative(5, bit_size) },
            TestParams { a: 254, b: to_negative(3, bit_size), result: 1 },
        ];
        evaluate_int_ops(test_ops, BinaryIntOp::Sub, bit_size);

        let bit_size = IntegerBitSize::U128;
        let test_ops = vec![
            TestParams { a: 5, b: 10, result: to_negative(5, bit_size) },
            TestParams { a: to_negative(3, bit_size), b: 2, result: to_negative(5, bit_size) },
        ];
        evaluate_int_ops(test_ops, BinaryIntOp::Sub, bit_size);

        assert_rejects_mismatched_operands(BinaryIntOp::Sub);
    }

    #[test]
    fn mul_test() {
        let bit_size = IntegerBitSize::U8;

        let test_ops = vec![
            TestParams { a: 5, b: 3, result: 15 },
            TestParams { a: 5, b: 100, result: 244 },
            TestParams { a: to_negative(1, bit_size), b: to_negative(5, bit_size), result: 5 },
            TestParams { a: to_negative(1, bit_size), b: 5, result: to_negative(5, bit_size) },
            TestParams { a: to_negative(2, bit_size), b: 7, result: to_negative(14, bit_size) },
        ];
        evaluate_int_ops(test_ops, BinaryIntOp::Mul, bit_size);

        // (2^64 - 1) * 3 wraps to 2^64 - 3.
        let bit_size = IntegerBitSize::U64;
        let a = 2_u128.pow(bit_size.into()) - 1;
        let b = 3;
        assert_eq!(evaluate_u128(&BinaryIntOp::Mul, a, b, bit_size), a - 2);

        let bit_size = IntegerBitSize::U128;
        let test_ops = vec![
            TestParams { a: to_negative(1, bit_size), b: to_negative(5, bit_size), result: 5 },
            TestParams { a: to_negative(1, bit_size), b: 5, result: to_negative(5, bit_size) },
            TestParams { a: to_negative(2, bit_size), b: 7, result: to_negative(14, bit_size) },
        ];
        evaluate_int_ops(test_ops, BinaryIntOp::Mul, bit_size);

        assert_rejects_mismatched_operands(BinaryIntOp::Mul);
    }

    #[test]
    fn div_test() {
        let bit_size = IntegerBitSize::U8;

        let test_ops =
            vec![TestParams { a: 5, b: 3, result: 1 }, TestParams { a: 5, b: 10, result: 0 }];
        evaluate_int_ops(test_ops, BinaryIntOp::Div, bit_size);

        assert_rejects_mismatched_operands(BinaryIntOp::Div);

        assert_eq!(
            evaluate_binary_int_op(
                &BinaryIntOp::Div,
                MemoryValue::<FieldElement>::U8(1),
                MemoryValue::<FieldElement>::U8(0),
                IntegerBitSize::U8
            ),
            Err(BrilligArithmeticError::DivisionByZero)
        );
    }

    #[test]
    fn shl_test() {
        let bit_size = IntegerBitSize::U8;

        let test_ops =
            vec![TestParams { a: 1, b: 7, result: 128 }, TestParams { a: 5, b: 7, result: 128 }];
        evaluate_int_ops(test_ops, BinaryIntOp::Shl, bit_size);

        // Shifting by the full width clears the value instead of overflowing.
        assert_eq!(
            evaluate_binary_int_op(
                &BinaryIntOp::Shl,
                MemoryValue::<FieldElement>::U8(1u8),
                MemoryValue::<FieldElement>::U8(8u8),
                IntegerBitSize::U8
            ),
            Ok(MemoryValue::<FieldElement>::U8(0u8))
        );

        // u1 shifts: by 0 is the identity, by 1 clears the bit.
        assert_eq!(
            evaluate_binary_int_op(
                &BinaryIntOp::Shl,
                MemoryValue::<FieldElement>::U1(true),
                MemoryValue::<FieldElement>::U1(false),
                IntegerBitSize::U1
            ),
            Ok(MemoryValue::<FieldElement>::U1(true))
        );
        assert_eq!(
            evaluate_binary_int_op(
                &BinaryIntOp::Shl,
                MemoryValue::<FieldElement>::U1(true),
                MemoryValue::<FieldElement>::U1(true),
                IntegerBitSize::U1
            ),
            Ok(MemoryValue::<FieldElement>::U1(false))
        );

        assert_rejects_mismatched_operands(BinaryIntOp::Shl);
    }

    #[test]
    fn shr_test() {
        let bit_size = IntegerBitSize::U8;

        let test_ops =
            vec![TestParams { a: 1, b: 0, result: 1 }, TestParams { a: 5, b: 1, result: 2 }];
        evaluate_int_ops(test_ops, BinaryIntOp::Shr, bit_size);

        // Shifting by the full width clears the value instead of overflowing.
        assert_eq!(
            evaluate_binary_int_op(
                &BinaryIntOp::Shr,
                MemoryValue::<FieldElement>::U8(1u8),
                MemoryValue::<FieldElement>::U8(8u8),
                IntegerBitSize::U8
            ),
            Ok(MemoryValue::<FieldElement>::U8(0u8))
        );

        // u1 shifts: by 0 is the identity, by 1 clears the bit.
        assert_eq!(
            evaluate_binary_int_op(
                &BinaryIntOp::Shr,
                MemoryValue::<FieldElement>::U1(true),
                MemoryValue::<FieldElement>::U1(false),
                IntegerBitSize::U1
            ),
            Ok(MemoryValue::<FieldElement>::U1(true))
        );
        assert_eq!(
            evaluate_binary_int_op(
                &BinaryIntOp::Shr,
                MemoryValue::<FieldElement>::U1(true),
                MemoryValue::<FieldElement>::U1(true),
                IntegerBitSize::U1
            ),
            Ok(MemoryValue::<FieldElement>::U1(false))
        );

        assert_rejects_mismatched_operands(BinaryIntOp::Shr);
    }

    /// Reference model for shifts at `bit_width`: out-of-range shift amounts yield zero and
    /// the result is truncated to `bit_width` bits.
    fn reference_shift(op: &BinaryIntOp, lhs: u128, rhs: u128, bit_width: u32) -> u128 {
        let mask = if bit_width == 128 { u128::MAX } else { (1u128 << bit_width) - 1 };
        let lhs = lhs & mask;
        let rhs = rhs & mask;
        if rhs >= u128::from(bit_width) {
            return 0;
        }
        let result = match op {
            BinaryIntOp::Shl => lhs << rhs as u32,
            BinaryIntOp::Shr => lhs >> rhs as u32,
            _ => unreachable!(),
        };
        result & mask
    }

    proptest::proptest! {
        #[test]
        fn shift_u8_fuzz(lhs in 0u128..=u128::from(u8::MAX), rhs in 0u128..=u128::from(u8::MAX)) {
            let bit_size = IntegerBitSize::U8;
            for op in [BinaryIntOp::Shl, BinaryIntOp::Shr] {
                let actual = evaluate_u128(&op, lhs, rhs, bit_size);
                let expected = reference_shift(&op, lhs, rhs, 8);
                proptest::prop_assert_eq!(actual, expected);
            }
        }

        #[test]
        fn shift_u16_fuzz(lhs in 0u128..=u128::from(u16::MAX), rhs in 0u128..=u128::from(u16::MAX)) {
            let bit_size = IntegerBitSize::U16;
            for op in [BinaryIntOp::Shl, BinaryIntOp::Shr] {
                let actual = evaluate_u128(&op, lhs, rhs, bit_size);
                let expected = reference_shift(&op, lhs, rhs, 16);
                proptest::prop_assert_eq!(actual, expected);
            }
        }

        #[test]
        fn shift_u32_fuzz(lhs in 0u128..=u128::from(u32::MAX), rhs in 0u128..=u128::from(u32::MAX)) {
            let bit_size = IntegerBitSize::U32;
            for op in [BinaryIntOp::Shl, BinaryIntOp::Shr] {
                let actual = evaluate_u128(&op, lhs, rhs, bit_size);
                let expected = reference_shift(&op, lhs, rhs, 32);
                proptest::prop_assert_eq!(actual, expected);
            }
        }

        #[test]
        fn shift_u64_fuzz(lhs in 0u128..=u128::from(u64::MAX), rhs in 0u128..=u128::from(u64::MAX)) {
            let bit_size = IntegerBitSize::U64;
            for op in [BinaryIntOp::Shl, BinaryIntOp::Shr] {
                let actual = evaluate_u128(&op, lhs, rhs, bit_size);
                let expected = reference_shift(&op, lhs, rhs, 64);
                proptest::prop_assert_eq!(actual, expected);
            }
        }

        #[test]
        fn shift_u128_fuzz(lhs: u128, rhs: u128) {
            let bit_size = IntegerBitSize::U128;
            for op in [BinaryIntOp::Shl, BinaryIntOp::Shr] {
                let actual = evaluate_u128(&op, lhs, rhs, bit_size);
                let expected = reference_shift(&op, lhs, rhs, 128);
                proptest::prop_assert_eq!(actual, expected);
            }
        }
    }

    #[test]
    fn comparison_ops_test() {
        let bit_size = IntegerBitSize::U8;

        let test_ops = vec![
            TestParams { a: 5, b: 5, result: 1 },
            TestParams { a: 10, b: 5, result: 0 },
            TestParams { a: 0, b: 0, result: 1 },
        ];
        evaluate_int_ops(test_ops, BinaryIntOp::Equals, bit_size);

        let test_ops = vec![
            TestParams { a: 4, b: 5, result: 1 },
            TestParams { a: 5, b: 4, result: 0 },
            TestParams { a: 5, b: 5, result: 0 },
        ];
        evaluate_int_ops(test_ops, BinaryIntOp::LessThan, bit_size);

        let test_ops = vec![
            TestParams { a: 4, b: 5, result: 1 },
            TestParams { a: 5, b: 4, result: 0 },
            TestParams { a: 5, b: 5, result: 1 },
        ];
        evaluate_int_ops(test_ops, BinaryIntOp::LessThanEquals, bit_size);

        for op in [BinaryIntOp::Equals, BinaryIntOp::LessThan, BinaryIntOp::LessThanEquals] {
            assert_rejects_mismatched_operands(op);
        }
    }
}
739
#[cfg(test)]
mod field_ops {
    use super::*;
    use acir::{AcirField, FieldElement};

    /// One test vector: evaluating the op on `a` and `b` must yield `result`.
    struct TestParams {
        a: FieldElement,
        b: FieldElement,
        result: FieldElement,
    }

    /// Evaluates `op` on two field operands and returns the resulting field element.
    ///
    /// Panics if the evaluation returns an error.
    fn evaluate_field(op: &BinaryFieldOp, a: FieldElement, b: FieldElement) -> FieldElement {
        let result_value: MemoryValue<FieldElement> =
            evaluate_binary_field_op(op, MemoryValue::new_field(a), MemoryValue::new_field(b))
                .unwrap();
        result_value.to_field()
    }

    /// Checks every vector in `test_params` against `op`.
    fn evaluate_field_ops(test_params: Vec<TestParams>, op: BinaryFieldOp) {
        for test in test_params {
            assert_eq!(evaluate_field(&op, test.a, test.b), test.result);
        }
    }

    /// Asserts that `op` rejects a `U16` operand on either side.
    ///
    /// Field operations report an expected bit size of 254 for `FieldElement`.
    fn assert_rejects_integer_operands(op: BinaryFieldOp) {
        assert_eq!(
            evaluate_binary_field_op(
                &op,
                MemoryValue::new_field(FieldElement::from(1u32)),
                MemoryValue::<FieldElement>::U16(1),
            ),
            Err(BrilligArithmeticError::MismatchedRhsBitSize {
                rhs_bit_size: 16,
                op_bit_size: 254
            })
        );
        assert_eq!(
            evaluate_binary_field_op(
                &op,
                MemoryValue::<FieldElement>::U16(1),
                MemoryValue::new_field(FieldElement::from(1u32)),
            ),
            Err(BrilligArithmeticError::MismatchedLhsBitSize {
                lhs_bit_size: 16,
                op_bit_size: 254
            })
        );
    }

    #[test]
    fn add_test() {
        let test_ops = vec![
            TestParams { a: 1u32.into(), b: 2u32.into(), result: 3u32.into() },
            TestParams { a: 5u32.into(), b: 10u32.into(), result: 15u32.into() },
            TestParams { a: 250u32.into(), b: 10u32.into(), result: 260u32.into() },
        ];
        evaluate_field_ops(test_ops, BinaryFieldOp::Add);

        assert_rejects_integer_operands(BinaryFieldOp::Add);
    }

    #[test]
    fn sub_test() {
        let test_ops = vec![
            TestParams { a: 5u32.into(), b: 3u32.into(), result: 2u32.into() },
            TestParams { a: 2u32.into(), b: 10u32.into(), result: FieldElement::from(-8_i128) },
        ];
        evaluate_field_ops(test_ops, BinaryFieldOp::Sub);

        assert_rejects_integer_operands(BinaryFieldOp::Sub);
    }

    #[test]
    fn mul_test() {
        let test_ops = vec![
            TestParams { a: 2u32.into(), b: 3u32.into(), result: 6u32.into() },
            TestParams { a: 10u32.into(), b: 25u32.into(), result: 250u32.into() },
        ];
        evaluate_field_ops(test_ops, BinaryFieldOp::Mul);

        assert_rejects_integer_operands(BinaryFieldOp::Mul);
    }

    #[test]
    fn div_test() {
        let test_ops = vec![
            TestParams { a: 10u32.into(), b: 2u32.into(), result: 5u32.into() },
            TestParams { a: 9u32.into(), b: 3u32.into(), result: 3u32.into() },
            TestParams {
                a: 10u32.into(),
                b: FieldElement::from(-1_i128),
                result: FieldElement::from(-10_i128),
            },
            TestParams { a: 10u32.into(), b: 1u32.into(), result: 10u32.into() },
            // 10 / 20 = 1/2, the multiplicative inverse of 2 in the field.
            TestParams {
                a: 10u32.into(),
                b: 20u32.into(),
                result: FieldElement::try_from_str(
                    "10944121435919637611123202872628637544274182200208017171849102093287904247809",
                )
                .unwrap(),
            },
            // 100 / 7 is a field element, not the truncated integer 14.
            TestParams {
                a: 100u32.into(),
                b: 7u32.into(),
                result: FieldElement::try_from_str(
                    "6253783677668364349213258784359221453870961257261724098199486910450230998762",
                )
                .unwrap(),
            },
        ];
        evaluate_field_ops(test_ops, BinaryFieldOp::Div);

        assert_eq!(
            evaluate_binary_field_op(
                &BinaryFieldOp::Div,
                MemoryValue::new_field(FieldElement::from(1u128)),
                MemoryValue::new_field(FieldElement::zero()),
            ),
            Err(BrilligArithmeticError::DivisionByZero)
        );

        assert_rejects_integer_operands(BinaryFieldOp::Div);
    }

    #[test]
    fn integer_div_test() {
        let test_ops = vec![
            TestParams { a: 10u32.into(), b: 2u32.into(), result: 5u32.into() },
            TestParams { a: 9u32.into(), b: 3u32.into(), result: 3u32.into() },
            // -1 is the largest field element, so any small dividend truncates to 0.
            TestParams { a: 10u32.into(), b: FieldElement::from(-1_i128), result: 0u32.into() },
            TestParams { a: 10u32.into(), b: 1u32.into(), result: 10u32.into() },
            TestParams { a: 10u32.into(), b: 20u32.into(), result: 0u32.into() },
            TestParams { a: 100u32.into(), b: 7u32.into(), result: 14u32.into() },
        ];
        evaluate_field_ops(test_ops, BinaryFieldOp::IntegerDiv);

        assert_eq!(
            evaluate_binary_field_op(
                &BinaryFieldOp::IntegerDiv,
                MemoryValue::new_field(FieldElement::from(1u128)),
                MemoryValue::new_field(FieldElement::zero()),
            ),
            Err(BrilligArithmeticError::DivisionByZero)
        );

        assert_rejects_integer_operands(BinaryFieldOp::IntegerDiv);
    }

    #[test]
    fn comparison_ops_test() {
        let test_ops = vec![
            TestParams { a: 5u32.into(), b: 5u32.into(), result: 1u32.into() },
            TestParams { a: 10u32.into(), b: 5u32.into(), result: 0u32.into() },
            TestParams { a: 0u32.into(), b: 0u32.into(), result: 1u32.into() },
        ];
        evaluate_field_ops(test_ops, BinaryFieldOp::Equals);

        let test_ops = vec![
            TestParams { a: 4u32.into(), b: 5u32.into(), result: 1u32.into() },
            TestParams { a: 5u32.into(), b: 4u32.into(), result: 0u32.into() },
            TestParams { a: 5u32.into(), b: 5u32.into(), result: 0u32.into() },
        ];
        evaluate_field_ops(test_ops, BinaryFieldOp::LessThan);

        let test_ops = vec![
            TestParams { a: 4u32.into(), b: 5u32.into(), result: 1u32.into() },
            TestParams { a: 5u32.into(), b: 4u32.into(), result: 0u32.into() },
            TestParams { a: 5u32.into(), b: 5u32.into(), result: 1u32.into() },
        ];
        evaluate_field_ops(test_ops, BinaryFieldOp::LessThanEquals);

        for op in [BinaryFieldOp::Equals, BinaryFieldOp::LessThan, BinaryFieldOp::LessThanEquals] {
            assert_rejects_integer_operands(op);
        }
    }
}
1049}