acvm_blackbox_solver/
logic.rs

1use acir::AcirField;
2
3pub fn bit_and<F: AcirField>(lhs: F, rhs: F, num_bits: u32) -> F {
4    // Fast path: use native u128 operations when the bit width fits,
5    // avoiding all heap allocations from field-to-byte conversions.
6    if let Some(result) = try_bitwise_u128(lhs, rhs, num_bits, |l, r| l & r) {
7        result
8    } else {
9        bitwise_op(lhs, rhs, num_bits, |lhs_byte, rhs_byte| lhs_byte & rhs_byte)
10    }
11}
12
13pub fn bit_xor<F: AcirField>(lhs: F, rhs: F, num_bits: u32) -> F {
14    // Fast path: use native u128 operations when the bit width fits,
15    // avoiding all heap allocations from field-to-byte conversions.
16    if let Some(result) = try_bitwise_u128(lhs, rhs, num_bits, |l, r| l ^ r) {
17        result
18    } else {
19        bitwise_op(lhs, rhs, num_bits, |lhs_byte, rhs_byte| lhs_byte ^ rhs_byte)
20    }
21}
22
23/// Attempt to perform a bitwise operation using native u128 arithmetic.
24/// Returns `None` if `num_bits > 128` or either operand doesn't fit in a u128.
25fn try_bitwise_u128<F: AcirField>(
26    lhs: F,
27    rhs: F,
28    num_bits: u32,
29    op: fn(u128, u128) -> u128,
30) -> Option<F> {
31    if num_bits > 128 {
32        return None;
33    }
34    let l = lhs.try_into_u128()?;
35    let r = rhs.try_into_u128()?;
36    let mask = if num_bits >= 128 { u128::MAX } else { (1u128 << num_bits) - 1 };
37    Some(F::from(op(l, r) & mask))
38}
39
40/// Performs a bitwise operation on two field elements by treating them as byte arrays.
41///
42/// Both field elements are converted to little-endian byte arrays and masked to keep only
43/// the lowest `num_bits` bits. The provided operation `op` is then applied byte-by-byte,
44/// and the result is converted back to a field element.
45/// This function works for any `num_bits` value and does not assume it to be a multiple of 8.
46fn bitwise_op<F: AcirField>(lhs: F, rhs: F, num_bits: u32, op: fn(u8, u8) -> u8) -> F {
47    let mut lhs_bytes = mask_to_le_bytes(lhs, num_bits);
48    let rhs_bytes = mask_to_le_bytes(rhs, num_bits);
49
50    // Operate in-place on lhs_bytes to avoid allocating a third Vec.
51    for (l, r) in lhs_bytes.iter_mut().zip(rhs_bytes.iter()) {
52        *l = op(*l, *r);
53    }
54
55    F::from_le_bytes_reduce(&lhs_bytes)
56}
57
58// mask_to methods will not remove any bytes from the field
59// they are simply zeroed out
60// Whereas truncate_to will remove those bits and make the byte array smaller
61fn mask_to_le_bytes<F: AcirField>(field: F, num_bits: u32) -> Vec<u8> {
62    let mut bytes = field.to_le_bytes();
63    mask_vector_le(&mut bytes, num_bits as usize);
64    bytes
65}
66
/// Clears, in place, every bit of a little-endian byte slice above its lowest `num_bits`.
///
/// The slice keeps its length. If `num_bits` covers the whole slice (or more), the
/// slice is left untouched.
fn mask_vector_le(bytes: &mut [u8], num_bits: usize) {
    let capacity_bits = bytes.len() * 8;
    if num_bits < capacity_bits {
        // Index of the byte holding the first bit to drop, and how many of its
        // low bits (0-7) survive.
        let boundary = num_bits / 8;
        let kept_in_boundary = (num_bits % 8) as u32;

        // `boundary < bytes.len()` holds here, so the tail always has a first byte.
        if let Some((partial, beyond)) = bytes[boundary..].split_first_mut() {
            // With zero surviving bits the mask is 0, which wipes the byte entirely;
            // that is what we want when `num_bits` is a multiple of 8.
            *partial &= (1u8 << kept_in_boundary) - 1;
            // Everything past the boundary byte is out of range.
            beyond.fill(0);
        }
    }
}
90
#[cfg(test)]
mod tests {
    use acir::FieldElement;
    use proptest::prelude::*;

    use crate::{bit_and, bit_xor};

    // Both properties use bit sizes of at least 128, so no bit of the u128 inputs is
    // masked away and the field result must equal the plain integer result.
    proptest! {
        #[test]
        fn matches_bitwise_and_on_u128s(x in 0..=u128::MAX, y in 0..=u128::MAX, bit_size in 128u32..) {
            let expected = FieldElement::from(x & y);
            let actual = bit_and(FieldElement::from(x), FieldElement::from(y), bit_size);

            prop_assert_eq!(actual, expected, "AND on fields should match that on integers");
        }

        #[test]
        fn matches_bitwise_xor_on_u128s(x in 0..=u128::MAX, y in 0..=u128::MAX, bit_size in 128u32..) {
            let expected = FieldElement::from(x ^ y);
            let actual = bit_xor(FieldElement::from(x), FieldElement::from(y), bit_size);

            prop_assert_eq!(actual, expected, "XOR on fields should match that on integers");
        }
    }
}