use std::ops::{BitAnd, BitOr, BitXor, Shl, Shr};
use acir::AcirField;
use acir::brillig::{BinaryFieldOp, BinaryIntOp, BitSize, IntegerBitSize};
use num_bigint::BigUint;
use num_traits::{CheckedDiv, ToPrimitive, WrappingAdd, WrappingMul, WrappingSub, Zero};
use crate::memory::{MemoryTypeError, MemoryValue};
/// Failures that can occur while evaluating a Brillig binary arithmetic operation.
#[derive(Debug, PartialEq, thiserror::Error)]
pub(crate) enum BrilligArithmeticError {
    /// The left operand's bit size differs from the bit size the operation was issued with.
    #[error("Bit size for lhs {lhs_bit_size} does not match op bit size {op_bit_size}")]
    MismatchedLhsBitSize {
        /// Bit size actually carried by the left operand.
        lhs_bit_size: u32,
        /// Bit size the operation expected.
        op_bit_size: u32,
    },
    /// The right operand's bit size differs from the bit size the operation was issued with.
    #[error("Bit size for rhs {rhs_bit_size} does not match op bit size {op_bit_size}")]
    MismatchedRhsBitSize {
        /// Bit size actually carried by the right operand.
        rhs_bit_size: u32,
        /// Bit size the operation expected.
        op_bit_size: u32,
    },
    /// A shift amount was not strictly smaller than the width of the shifted type.
    #[error("Attempted to shift by {shift_size} bits on a type of bit size {bit_size}")]
    BitshiftOverflow {
        /// Width, in bits, of the value being shifted.
        bit_size: u32,
        /// Requested shift amount.
        shift_size: u128,
    },
    /// A division (field or integer) had a zero divisor.
    #[error("Attempted to divide by zero")]
    DivisionByZero,
}
/// Evaluates a binary operation whose operands are both field elements.
///
/// Both `lhs` and `rhs` must be field-typed memory values; a bit-size mismatch reported by
/// `expect_field` is translated into the corresponding lhs/rhs arithmetic error.
///
/// `Div` is field division (multiplication by the inverse), while `IntegerDiv` interprets both
/// operands as unsigned big integers and returns the truncated quotient reduced back into the
/// field. Both forms return [`BrilligArithmeticError::DivisionByZero`] for a zero divisor.
/// Comparison operators produce a boolean memory value.
pub(crate) fn evaluate_binary_field_op<F: AcirField>(
    op: &BinaryFieldOp,
    lhs: MemoryValue<F>,
    rhs: MemoryValue<F>,
) -> Result<MemoryValue<F>, BrilligArithmeticError> {
    let a = match lhs.expect_field() {
        Ok(field) => field,
        Err(MemoryTypeError::MismatchedBitSize { value_bit_size, expected_bit_size }) => {
            return Err(BrilligArithmeticError::MismatchedLhsBitSize {
                lhs_bit_size: value_bit_size,
                op_bit_size: expected_bit_size,
            });
        }
        Err(_) => unreachable!("MemoryTypeError NotInteger is only produced by to_u128"),
    };
    let b = match rhs.expect_field() {
        Ok(field) => field,
        Err(MemoryTypeError::MismatchedBitSize { value_bit_size, expected_bit_size }) => {
            return Err(BrilligArithmeticError::MismatchedRhsBitSize {
                rhs_bit_size: value_bit_size,
                op_bit_size: expected_bit_size,
            });
        }
        Err(_) => unreachable!("MemoryTypeError NotInteger is only produced by to_u128"),
    };

    let outcome = match op {
        BinaryFieldOp::Add => MemoryValue::new_field(a + b),
        BinaryFieldOp::Sub => MemoryValue::new_field(a - b),
        BinaryFieldOp::Mul => MemoryValue::new_field(a * b),
        BinaryFieldOp::Div => {
            if b.is_zero() {
                return Err(BrilligArithmeticError::DivisionByZero);
            }
            // Dividing by 1 or -1 needs no field inversion, so short-circuit those divisors.
            let quotient = if b.is_one() {
                a
            } else if b == -F::one() {
                -a
            } else {
                a / b
            };
            MemoryValue::new_field(quotient)
        }
        BinaryFieldOp::IntegerDiv => {
            if b.is_zero() {
                return Err(BrilligArithmeticError::DivisionByZero);
            }
            let as_biguint = |value: F| BigUint::from_bytes_be(&value.to_be_bytes());
            let quotient = as_biguint(a) / as_biguint(b);
            MemoryValue::new_field(F::from_be_bytes_reduce(&quotient.to_bytes_be()))
        }
        BinaryFieldOp::Equals => (a == b).into(),
        BinaryFieldOp::LessThan => (a < b).into(),
        BinaryFieldOp::LessThanEquals => (a <= b).into(),
    };
    Ok(outcome)
}
/// Evaluates a binary integer operation on two memory values of width `bit_size`.
///
/// Both operands must be integer memory values whose bit size equals `bit_size`. Otherwise
/// [`BrilligArithmeticError::MismatchedLhsBitSize`] or
/// [`BrilligArithmeticError::MismatchedRhsBitSize`] is returned, naming whichever operand is
/// actually mismatched (the lhs is checked first).
///
/// - `Add`, `Sub`, `Mul`, `Div`, `And`, `Or`, `Xor` return a value of the same width;
///   `Add`/`Sub`/`Mul` wrap on overflow and `Div` by zero returns
///   [`BrilligArithmeticError::DivisionByZero`].
/// - `Equals`, `LessThan`, `LessThanEquals` always return a `U1`.
/// - `Shl`, `Shr` return [`BrilligArithmeticError::BitshiftOverflow`] when the shift amount is
///   not strictly smaller than the operand width.
pub(crate) fn evaluate_binary_int_op<F: AcirField>(
    op: &BinaryIntOp,
    lhs: MemoryValue<F>,
    rhs: MemoryValue<F>,
    bit_size: IntegerBitSize,
) -> Result<MemoryValue<F>, BrilligArithmeticError> {
    match op {
        BinaryIntOp::Add
        | BinaryIntOp::Sub
        | BinaryIntOp::Mul
        | BinaryIntOp::Div
        | BinaryIntOp::And
        | BinaryIntOp::Or
        | BinaryIntOp::Xor => match (lhs, rhs, bit_size) {
            (MemoryValue::U1(lhs), MemoryValue::U1(rhs), IntegerBitSize::U1) => {
                evaluate_binary_int_op_u1(op, lhs, rhs).map(MemoryValue::U1)
            }
            (MemoryValue::U8(lhs), MemoryValue::U8(rhs), IntegerBitSize::U8) => {
                evaluate_binary_int_op_arith(op, lhs, rhs).map(MemoryValue::U8)
            }
            (MemoryValue::U16(lhs), MemoryValue::U16(rhs), IntegerBitSize::U16) => {
                evaluate_binary_int_op_arith(op, lhs, rhs).map(MemoryValue::U16)
            }
            (MemoryValue::U32(lhs), MemoryValue::U32(rhs), IntegerBitSize::U32) => {
                evaluate_binary_int_op_arith(op, lhs, rhs).map(MemoryValue::U32)
            }
            (MemoryValue::U64(lhs), MemoryValue::U64(rhs), IntegerBitSize::U64) => {
                evaluate_binary_int_op_arith(op, lhs, rhs).map(MemoryValue::U64)
            }
            (MemoryValue::U128(lhs), MemoryValue::U128(rhs), IntegerBitSize::U128) => {
                evaluate_binary_int_op_arith(op, lhs, rhs).map(MemoryValue::U128)
            }
            (lhs, rhs, _) => Err(mismatched_bit_size_error(lhs, rhs, bit_size)),
        },
        BinaryIntOp::Equals | BinaryIntOp::LessThan | BinaryIntOp::LessThanEquals => {
            match (lhs, rhs, bit_size) {
                (MemoryValue::U1(lhs), MemoryValue::U1(rhs), IntegerBitSize::U1) => {
                    Ok(MemoryValue::U1(evaluate_binary_int_op_cmp(op, lhs, rhs)))
                }
                (MemoryValue::U8(lhs), MemoryValue::U8(rhs), IntegerBitSize::U8) => {
                    Ok(MemoryValue::U1(evaluate_binary_int_op_cmp(op, lhs, rhs)))
                }
                (MemoryValue::U16(lhs), MemoryValue::U16(rhs), IntegerBitSize::U16) => {
                    Ok(MemoryValue::U1(evaluate_binary_int_op_cmp(op, lhs, rhs)))
                }
                (MemoryValue::U32(lhs), MemoryValue::U32(rhs), IntegerBitSize::U32) => {
                    Ok(MemoryValue::U1(evaluate_binary_int_op_cmp(op, lhs, rhs)))
                }
                (MemoryValue::U64(lhs), MemoryValue::U64(rhs), IntegerBitSize::U64) => {
                    Ok(MemoryValue::U1(evaluate_binary_int_op_cmp(op, lhs, rhs)))
                }
                (MemoryValue::U128(lhs), MemoryValue::U128(rhs), IntegerBitSize::U128) => {
                    Ok(MemoryValue::U1(evaluate_binary_int_op_cmp(op, lhs, rhs)))
                }
                (lhs, rhs, _) => Err(mismatched_bit_size_error(lhs, rhs, bit_size)),
            }
        }
        BinaryIntOp::Shl | BinaryIntOp::Shr => match (lhs, rhs, bit_size) {
            (MemoryValue::U1(lhs), MemoryValue::U1(rhs), IntegerBitSize::U1) => {
                // A u1 can only be shifted by 0; shifting by 1 already exceeds its width.
                if rhs {
                    Err(BrilligArithmeticError::BitshiftOverflow { bit_size: 1, shift_size: 1 })
                } else {
                    Ok(MemoryValue::U1(lhs))
                }
            }
            (MemoryValue::U8(lhs), MemoryValue::U8(rhs), IntegerBitSize::U8) => {
                if rhs < 8 {
                    Ok(MemoryValue::U8(evaluate_binary_int_op_shifts(op, lhs, rhs)))
                } else {
                    Err(BrilligArithmeticError::BitshiftOverflow {
                        bit_size: 8,
                        shift_size: rhs as u128,
                    })
                }
            }
            (MemoryValue::U16(lhs), MemoryValue::U16(rhs), IntegerBitSize::U16) => {
                if rhs < 16 {
                    Ok(MemoryValue::U16(evaluate_binary_int_op_shifts(op, lhs, rhs)))
                } else {
                    Err(BrilligArithmeticError::BitshiftOverflow {
                        bit_size: 16,
                        shift_size: rhs as u128,
                    })
                }
            }
            (MemoryValue::U32(lhs), MemoryValue::U32(rhs), IntegerBitSize::U32) => {
                if rhs < 32 {
                    Ok(MemoryValue::U32(evaluate_binary_int_op_shifts(op, lhs, rhs)))
                } else {
                    Err(BrilligArithmeticError::BitshiftOverflow {
                        bit_size: 32,
                        shift_size: rhs as u128,
                    })
                }
            }
            (MemoryValue::U64(lhs), MemoryValue::U64(rhs), IntegerBitSize::U64) => {
                if rhs < 64 {
                    Ok(MemoryValue::U64(evaluate_binary_int_op_shifts(op, lhs, rhs)))
                } else {
                    Err(BrilligArithmeticError::BitshiftOverflow {
                        bit_size: 64,
                        shift_size: rhs as u128,
                    })
                }
            }
            (MemoryValue::U128(lhs), MemoryValue::U128(rhs), IntegerBitSize::U128) => {
                if rhs < 128 {
                    Ok(MemoryValue::U128(evaluate_binary_int_op_shifts(op, lhs, rhs)))
                } else {
                    Err(BrilligArithmeticError::BitshiftOverflow { bit_size: 128, shift_size: rhs })
                }
            }
            // Report the operand that is really mismatched, not unconditionally the lhs.
            (lhs, rhs, _) => Err(mismatched_bit_size_error(lhs, rhs, bit_size)),
        },
    }
}

/// Builds the bit-size mismatch error for a pair of operands that did not match any typed arm.
///
/// The lhs is checked first, so when both operands are mismatched the lhs error is reported.
///
/// # Panics
/// Panics if both operands already have bit size `bit_size`: such pairs are always handled by
/// the typed arms of [`evaluate_binary_int_op`], so reaching this case is a logic error.
fn mismatched_bit_size_error<F: AcirField>(
    lhs: MemoryValue<F>,
    rhs: MemoryValue<F>,
    bit_size: IntegerBitSize,
) -> BrilligArithmeticError {
    let expected = BitSize::Integer(bit_size);
    if lhs.bit_size() != expected {
        BrilligArithmeticError::MismatchedLhsBitSize {
            lhs_bit_size: lhs.bit_size().to_u32::<F>(),
            op_bit_size: bit_size.into(),
        }
    } else if rhs.bit_size() != expected {
        BrilligArithmeticError::MismatchedRhsBitSize {
            rhs_bit_size: rhs.bit_size().to_u32::<F>(),
            op_bit_size: bit_size.into(),
        }
    } else {
        unreachable!("Operands matching the op bit size are handled by the typed match arms.")
    }
}
/// Evaluates an arithmetic, bitwise or comparison operator on two `u1` (boolean) operands.
///
/// Arithmetic is modulo 2: `Add` and `Sub` both reduce to XOR, and `Mul` reduces to AND.
/// Dividing by `false` (zero) returns [`BrilligArithmeticError::DivisionByZero`]; dividing by
/// `true` (one) returns `lhs` unchanged.
///
/// # Panics
/// Panics on any operator not listed above (e.g. the shift operators).
fn evaluate_binary_int_op_u1(
    op: &BinaryIntOp,
    lhs: bool,
    rhs: bool,
) -> Result<bool, BrilligArithmeticError> {
    Ok(match op {
        BinaryIntOp::Equals => lhs == rhs,
        // For booleans `false < true`, so this is true only for (false, true).
        BinaryIntOp::LessThan => lhs < rhs,
        BinaryIntOp::LessThanEquals => !lhs || rhs,
        BinaryIntOp::And | BinaryIntOp::Mul => lhs && rhs,
        BinaryIntOp::Or => lhs || rhs,
        BinaryIntOp::Xor | BinaryIntOp::Add | BinaryIntOp::Sub => lhs != rhs,
        BinaryIntOp::Div if !rhs => return Err(BrilligArithmeticError::DivisionByZero),
        BinaryIntOp::Div => lhs,
        other => unreachable!("Operator not handled by this function: {other:?}"),
    })
}
/// Applies a comparison operator (`Equals`, `LessThan`, `LessThanEquals`) to two ordered values.
///
/// # Panics
/// Panics if `op` is not one of the three comparison operators.
fn evaluate_binary_int_op_cmp<T: Ord + PartialEq>(op: &BinaryIntOp, lhs: T, rhs: T) -> bool {
    match op {
        BinaryIntOp::Equals => lhs.eq(&rhs),
        BinaryIntOp::LessThan => lhs.lt(&rhs),
        BinaryIntOp::LessThanEquals => lhs.le(&rhs),
        other => unreachable!("Operator not handled by this function: {other:?}"),
    }
}
/// Shifts `lhs` left (`Shl`) or right (`Shr`) by `rhs` bits.
///
/// A shift amount of at least the bit width of `T` yields zero rather than overflowing.
///
/// # Panics
/// Panics if `op` is not a shift operator, or if `rhs` does not fit in a `usize`.
fn evaluate_binary_int_op_shifts<T: ToPrimitive + Zero + Shl<Output = T> + Shr<Output = T>>(
    op: &BinaryIntOp,
    lhs: T,
    rhs: T,
) -> T {
    // Resolve the direction before touching `rhs` so an unsupported operator fails first.
    let shifts_left = match op {
        BinaryIntOp::Shl => true,
        BinaryIntOp::Shr => false,
        _ => unreachable!("Operator not handled by this function: {op:?}"),
    };
    let amount: usize = rhs.to_usize().expect("Could not convert rhs to usize");
    let width_in_bits = 8 * size_of::<T>();
    if amount >= width_in_bits {
        T::zero()
    } else if shifts_left {
        lhs << rhs
    } else {
        lhs >> rhs
    }
}
/// Evaluates a wrapping arithmetic or bitwise operator on two unsigned integers of type `T`.
///
/// `Add`, `Sub` and `Mul` wrap around on overflow; `Div` returns
/// [`BrilligArithmeticError::DivisionByZero`] when `checked_div` rejects the divisor.
///
/// # Panics
/// Panics if `op` is a comparison or shift operator.
fn evaluate_binary_int_op_arith<T>(
    op: &BinaryIntOp,
    lhs: T,
    rhs: T,
) -> Result<T, BrilligArithmeticError>
where
    T: WrappingAdd
        + WrappingSub
        + WrappingMul
        + CheckedDiv
        + BitAnd<Output = T>
        + BitOr<Output = T>
        + BitXor<Output = T>,
{
    Ok(match op {
        BinaryIntOp::Add => lhs.wrapping_add(&rhs),
        BinaryIntOp::Sub => lhs.wrapping_sub(&rhs),
        BinaryIntOp::Mul => lhs.wrapping_mul(&rhs),
        BinaryIntOp::Div => match lhs.checked_div(&rhs) {
            Some(quotient) => quotient,
            None => return Err(BrilligArithmeticError::DivisionByZero),
        },
        BinaryIntOp::And => lhs & rhs,
        BinaryIntOp::Or => lhs | rhs,
        BinaryIntOp::Xor => lhs ^ rhs,
        other => unreachable!("Operator not handled by this function: {other:?}"),
    })
}
#[cfg(test)]
mod tests {
    use super::*;
    use acir::{AcirField, FieldElement};

    /// A single `a <op> b == result` expectation; operands are raw unsigned bit patterns.
    struct TestParams {
        a: u128,
        b: u128,
        result: u128,
    }

    /// Runs `op` on two integers of width `bit_size` and returns the result as a `u128`.
    fn evaluate_u128(op: &BinaryIntOp, a: u128, b: u128, bit_size: IntegerBitSize) -> u128 {
        let lhs = MemoryValue::new_integer(a, bit_size);
        let rhs = MemoryValue::new_integer(b, bit_size);
        let output: MemoryValue<FieldElement> =
            evaluate_binary_int_op(op, lhs, rhs, bit_size).unwrap();
        output.to_field().to_u128()
    }

    /// Two's-complement encoding of `-a` at width `bit_size`, for a strictly positive `a`.
    fn to_negative(a: u128, bit_size: IntegerBitSize) -> u128 {
        assert!(a > 0);
        match bit_size {
            IntegerBitSize::U128 => a.wrapping_neg(),
            narrower => {
                let bits: u32 = narrower.into();
                (1_u128 << bits) - a
            }
        }
    }

    /// Asserts every expectation in `test_params` for operator `op` at width `bit_size`.
    fn evaluate_int_ops(test_params: Vec<TestParams>, op: BinaryIntOp, bit_size: IntegerBitSize) {
        test_params.into_iter().for_each(|TestParams { a, b, result }| {
            assert_eq!(evaluate_u128(&op, a, b, bit_size), result);
        });
    }

    #[test]
    fn add_test() {
        let width = IntegerBitSize::U8;
        evaluate_int_ops(
            vec![
                TestParams { a: 50, b: 100, result: 150 },
                TestParams { a: 250, b: 10, result: 4 },
                TestParams { a: 5, b: to_negative(3, width), result: 2 },
                TestParams { a: to_negative(3, width), b: 1, result: to_negative(2, width) },
                TestParams { a: 5, b: to_negative(6, width), result: to_negative(1, width) },
            ],
            BinaryIntOp::Add,
            width,
        );

        let width = IntegerBitSize::U128;
        evaluate_int_ops(
            vec![
                TestParams { a: 5, b: to_negative(3, width), result: 2 },
                TestParams { a: to_negative(3, width), b: 1, result: to_negative(2, width) },
            ],
            BinaryIntOp::Add,
            width,
        );
    }

    #[test]
    fn sub_test() {
        let width = IntegerBitSize::U8;
        evaluate_int_ops(
            vec![
                TestParams { a: 50, b: 30, result: 20 },
                TestParams { a: 5, b: 10, result: to_negative(5, width) },
                TestParams { a: 5, b: to_negative(3, width), result: 8 },
                TestParams { a: to_negative(3, width), b: 2, result: to_negative(5, width) },
                TestParams { a: 254, b: to_negative(3, width), result: 1 },
            ],
            BinaryIntOp::Sub,
            width,
        );

        let width = IntegerBitSize::U128;
        evaluate_int_ops(
            vec![
                TestParams { a: 5, b: 10, result: to_negative(5, width) },
                TestParams { a: to_negative(3, width), b: 2, result: to_negative(5, width) },
            ],
            BinaryIntOp::Sub,
            width,
        );
    }

    #[test]
    fn mul_test() {
        let width = IntegerBitSize::U8;
        evaluate_int_ops(
            vec![
                TestParams { a: 5, b: 3, result: 15 },
                TestParams { a: 5, b: 100, result: 244 },
                TestParams { a: to_negative(1, width), b: to_negative(5, width), result: 5 },
                TestParams { a: to_negative(1, width), b: 5, result: to_negative(5, width) },
                TestParams { a: to_negative(2, width), b: 7, result: to_negative(14, width) },
            ],
            BinaryIntOp::Mul,
            width,
        );

        // (2^64 - 1) * 3 wraps to 2^64 - 3 at 64 bits.
        let width = IntegerBitSize::U64;
        let max_u64 = u128::from(u64::MAX);
        assert_eq!(evaluate_u128(&BinaryIntOp::Mul, max_u64, 3, width), max_u64 - 2);

        let width = IntegerBitSize::U128;
        evaluate_int_ops(
            vec![
                TestParams { a: to_negative(1, width), b: to_negative(5, width), result: 5 },
                TestParams { a: to_negative(1, width), b: 5, result: to_negative(5, width) },
                TestParams { a: to_negative(2, width), b: 7, result: to_negative(14, width) },
            ],
            BinaryIntOp::Mul,
            width,
        );
    }

    #[test]
    fn div_test() {
        evaluate_int_ops(
            vec![TestParams { a: 5, b: 3, result: 1 }, TestParams { a: 5, b: 10, result: 0 }],
            BinaryIntOp::Div,
            IntegerBitSize::U8,
        );
    }

    #[test]
    fn shl_test() {
        evaluate_int_ops(
            vec![TestParams { a: 1, b: 7, result: 128 }, TestParams { a: 5, b: 7, result: 128 }],
            BinaryIntOp::Shl,
            IntegerBitSize::U8,
        );

        let overflowing = evaluate_binary_int_op(
            &BinaryIntOp::Shl,
            MemoryValue::<FieldElement>::U8(1u8),
            MemoryValue::<FieldElement>::U8(8u8),
            IntegerBitSize::U8,
        );
        assert_eq!(
            overflowing,
            Err(BrilligArithmeticError::BitshiftOverflow { bit_size: 8, shift_size: 8 })
        );
    }

    #[test]
    fn shr_test() {
        evaluate_int_ops(
            vec![TestParams { a: 1, b: 0, result: 1 }, TestParams { a: 5, b: 1, result: 2 }],
            BinaryIntOp::Shr,
            IntegerBitSize::U8,
        );

        let overflowing = evaluate_binary_int_op(
            &BinaryIntOp::Shr,
            MemoryValue::<FieldElement>::U8(1u8),
            MemoryValue::<FieldElement>::U8(8u8),
            IntegerBitSize::U8,
        );
        assert_eq!(
            overflowing,
            Err(BrilligArithmeticError::BitshiftOverflow { bit_size: 8, shift_size: 8 })
        );
    }
}