acvm/pwg/blackbox/
logic.rs

1use crate::OpcodeResolutionError;
2use crate::pwg::{check_bit_size, input_to_value, insert_value};
3use acir::circuit::opcodes::FunctionInput;
4use acir::{
5    AcirField,
6    native_types::{Witness, WitnessMap},
7};
8use acvm_blackbox_solver::{bit_and, bit_xor};
9
10/// Solves a [`BlackBoxFunc::And`][acir::circuit::black_box_functions::BlackBoxFunc::AND] opcode and inserts
11/// the result into the supplied witness map
12pub(super) fn and<F: AcirField>(
13    initial_witness: &mut WitnessMap<F>,
14    lhs: &FunctionInput<F>,
15    rhs: &FunctionInput<F>,
16    num_bits: u32,
17    output: &Witness,
18) -> Result<(), OpcodeResolutionError<F>> {
19    solve_logic_opcode(initial_witness, lhs, rhs, num_bits, *output, |left, right| {
20        bit_and(left, right, num_bits)
21    })
22}
23
24/// Solves a [`BlackBoxFunc::XOR`][acir::circuit::black_box_functions::BlackBoxFunc::XOR] opcode and inserts
25/// the result into the supplied witness map
26pub(super) fn xor<F: AcirField>(
27    initial_witness: &mut WitnessMap<F>,
28    lhs: &FunctionInput<F>,
29    rhs: &FunctionInput<F>,
30    num_bits: u32,
31    output: &Witness,
32) -> Result<(), OpcodeResolutionError<F>> {
33    solve_logic_opcode(initial_witness, lhs, rhs, num_bits, *output, |left, right| {
34        bit_xor(left, right, num_bits)
35    })
36}
37
38/// Derives the rest of the witness based on the initial low level variables
39fn solve_logic_opcode<F: AcirField>(
40    initial_witness: &mut WitnessMap<F>,
41    a: &FunctionInput<F>,
42    b: &FunctionInput<F>,
43    num_bits: u32,
44    result: Witness,
45    logic_op: impl Fn(F, F) -> F,
46) -> Result<(), OpcodeResolutionError<F>> {
47    let w_l_value = input_to_value(initial_witness, *a)?;
48    let w_r_value = input_to_value(initial_witness, *b)?;
49    let assignment = logic_op(w_l_value, w_r_value);
50    check_bit_size(w_l_value, num_bits)?;
51    check_bit_size(w_r_value, num_bits)?;
52
53    insert_value(&result, assignment, initial_witness)
54}
55
56#[cfg(test)]
57mod tests {
58    use crate::pwg::blackbox::{and, xor};
59    use acir::{
60        FieldElement, InvalidInputBitSize,
61        circuit::opcodes::FunctionInput,
62        native_types::{Witness, WitnessMap},
63    };
64    use std::collections::BTreeMap;
65
66    mod and {
67        use super::*;
68
69        #[test]
70        fn smoke_test() {
71            let lhs = FunctionInput::Witness(Witness(1));
72            let rhs = FunctionInput::Witness(Witness(2));
73
74            let mut initial_witness = WitnessMap::from(BTreeMap::from_iter([
75                (Witness(1), FieldElement::from(5u128)),
76                (Witness(2), FieldElement::from(8u128)),
77            ]));
78            and(&mut initial_witness, &lhs, &rhs, 8, &Witness(3)).unwrap();
79            assert_eq!(initial_witness[&Witness(3)], FieldElement::from(0u128));
80
81            let mut initial_witness = WitnessMap::from(BTreeMap::from_iter([
82                (Witness(1), FieldElement::from(26u128)),
83                (Witness(2), FieldElement::from(34u128)),
84            ]));
85            and(&mut initial_witness, &lhs, &rhs, 8, &Witness(3)).unwrap();
86            assert_eq!(initial_witness[&Witness(3)], FieldElement::from(2u128));
87        }
88
89        #[test]
90        fn errors_if_input_is_too_large() {
91            let lhs = FunctionInput::Witness(Witness(1));
92            let rhs = FunctionInput::Witness(Witness(2));
93
94            let mut initial_witness = WitnessMap::from(BTreeMap::from_iter([
95                (Witness(1), FieldElement::from(5u128)),
96                (Witness(2), FieldElement::from(256u128)),
97            ]));
98            let result = and(&mut initial_witness, &lhs, &rhs, 8, &Witness(3));
99            assert_eq!(
100                result,
101                Err(crate::pwg::OpcodeResolutionError::InvalidInputBitSize {
102                    opcode_location: crate::pwg::ErrorLocation::Unresolved,
103                    invalid_input_bit_size: InvalidInputBitSize {
104                        value: "256".to_string(),
105                        value_num_bits: 9,
106                        max_bits: 8
107                    },
108                })
109            );
110        }
111    }
112
113    mod xor {
114        use super::*;
115
116        #[test]
117        fn test_xor() {
118            let lhs = FunctionInput::Witness(Witness(1));
119            let rhs = FunctionInput::Witness(Witness(2));
120
121            let mut initial_witness = WitnessMap::from(BTreeMap::from_iter([
122                (Witness(1), FieldElement::from(5u128)),
123                (Witness(2), FieldElement::from(8u128)),
124            ]));
125            xor(&mut initial_witness, &lhs, &rhs, 8, &Witness(3)).unwrap();
126            assert_eq!(initial_witness[&Witness(3)], FieldElement::from(13u128));
127
128            let mut initial_witness = WitnessMap::from(BTreeMap::from_iter([
129                (Witness(1), FieldElement::from(26u128)),
130                (Witness(2), FieldElement::from(34u128)),
131            ]));
132            xor(&mut initial_witness, &lhs, &rhs, 8, &Witness(3)).unwrap();
133            assert_eq!(initial_witness[&Witness(3)], FieldElement::from(56u128));
134        }
135
136        #[test]
137        fn errors_if_input_is_too_large() {
138            let lhs = FunctionInput::Witness(Witness(1));
139            let rhs = FunctionInput::Witness(Witness(2));
140
141            let mut initial_witness = WitnessMap::from(BTreeMap::from_iter([
142                (Witness(1), FieldElement::from(5u128)),
143                (Witness(2), FieldElement::from(256u128)),
144            ]));
145            let result = xor(&mut initial_witness, &lhs, &rhs, 8, &Witness(3));
146            assert_eq!(
147                result,
148                Err(crate::pwg::OpcodeResolutionError::InvalidInputBitSize {
149                    opcode_location: crate::pwg::ErrorLocation::Unresolved,
150                    invalid_input_bit_size: InvalidInputBitSize {
151                        value: "256".to_string(),
152                        value_num_bits: 9,
153                        max_bits: 8
154                    },
155                })
156            );
157        }
158    }
159}