acvm/pwg/blackbox/hash.rs

1use acir::{
2    AcirField,
3    circuit::opcodes::FunctionInput,
4    native_types::{Witness, WitnessMap},
5};
6use acvm_blackbox_solver::{BlackBoxFunctionSolver, BlackBoxResolutionError, sha256_compression};
7
8use crate::OpcodeResolutionError;
9use crate::pwg::{input_to_value, insert_value};
10
11/// Attempts to solve a 256 bit hash function opcode.
12/// If successful, `initial_witness` will be mutated to contain the new witness assignment.
13pub(super) fn solve_generic_256_hash_opcode<F: AcirField>(
14    initial_witness: &mut WitnessMap<F>,
15    inputs: &[FunctionInput<F>],
16    var_message_size: Option<&FunctionInput<F>>,
17    outputs: &[Witness; 32],
18    hash_function: fn(data: &[u8]) -> Result<[u8; 32], BlackBoxResolutionError>,
19) -> Result<(), OpcodeResolutionError<F>> {
20    let message_input = get_hash_input(initial_witness, inputs, var_message_size, 8)?;
21    let digest: [u8; 32] = hash_function(&message_input)?;
22
23    write_digest_to_outputs(initial_witness, outputs, digest)
24}
25
26/// Reads the hash function input from a [`WitnessMap`].
27pub(crate) fn get_hash_input<F: AcirField>(
28    initial_witness: &WitnessMap<F>,
29    inputs: &[FunctionInput<F>],
30    message_size: Option<&FunctionInput<F>>,
31    num_bits: usize,
32) -> Result<Vec<u8>, OpcodeResolutionError<F>> {
33    // Read witness assignments.
34    let mut message_input = Vec::new();
35    for input in inputs {
36        let witness_assignment = input_to_value(initial_witness, *input)?;
37        let bytes = witness_assignment.fetch_nearest_bytes(num_bits);
38        message_input.extend(bytes);
39    }
40
41    // Truncate the message if there is a `message_size` parameter given
42    match message_size {
43        Some(input) => {
44            let num_bytes_to_take = input_to_value(initial_witness, *input)?
45                .try_into_u128()
46                .map(|num_bytes_to_take| num_bytes_to_take as usize)
47                .expect("expected a 'num_bytes_to_take' that fit into a u128");
48
49            // If the number of bytes to take is more than the amount of bytes available
50            // in the message, then we error.
51            if num_bytes_to_take > message_input.len() {
52                return Err(OpcodeResolutionError::BlackBoxFunctionFailed(
53                    acir::BlackBoxFunc::Blake2s,
54                    format!(
55                        "the number of bytes to take from the message is more than the number of bytes in the message. {} > {}",
56                        num_bytes_to_take,
57                        message_input.len()
58                    ),
59                ));
60            }
61            let truncated_message = message_input[0..num_bytes_to_take].to_vec();
62            Ok(truncated_message)
63        }
64        None => Ok(message_input),
65    }
66}
67
68/// Writes a `digest` to the [`WitnessMap`] at witness indices `outputs`.
69fn write_digest_to_outputs<F: AcirField>(
70    initial_witness: &mut WitnessMap<F>,
71    outputs: &[Witness; 32],
72    digest: [u8; 32],
73) -> Result<(), OpcodeResolutionError<F>> {
74    for (output_witness, value) in outputs.iter().zip(digest.into_iter()) {
75        insert_value(output_witness, F::from_be_bytes_reduce(&[value]), initial_witness)?;
76    }
77
78    Ok(())
79}
80
81fn to_u32_array<const N: usize, F: AcirField>(
82    initial_witness: &WitnessMap<F>,
83    inputs: &[FunctionInput<F>; N],
84) -> Result<[u32; N], OpcodeResolutionError<F>> {
85    let mut result = [0; N];
86    for (it, input) in result.iter_mut().zip(inputs) {
87        let witness_value = input_to_value(initial_witness, *input)?;
88        *it = witness_value
89            .try_into_u128()
90            .expect("expected the 'witness_value' to fit into a u128")
91            .try_into()
92            .expect("expected the 'witness_value' to fit into a u32");
93    }
94    Ok(result)
95}
96
97pub(crate) fn solve_sha_256_permutation_opcode<F: AcirField>(
98    initial_witness: &mut WitnessMap<F>,
99    inputs: &[FunctionInput<F>; 16],
100    hash_values: &[FunctionInput<F>; 8],
101    outputs: &[Witness; 8],
102) -> Result<(), OpcodeResolutionError<F>> {
103    let state = execute_sha_256_permutation_opcode(initial_witness, inputs, hash_values)?;
104
105    for (output_witness, value) in outputs.iter().zip(state.into_iter()) {
106        insert_value(output_witness, F::from(u128::from(value)), initial_witness)?;
107    }
108
109    Ok(())
110}
111
112pub(crate) fn execute_sha_256_permutation_opcode<F: AcirField>(
113    initial_witness: &WitnessMap<F>,
114    inputs: &[FunctionInput<F>; 16],
115    hash_values: &[FunctionInput<F>; 8],
116) -> Result<[u32; 8], OpcodeResolutionError<F>> {
117    let message = to_u32_array(initial_witness, inputs)?;
118    let mut state = to_u32_array(initial_witness, hash_values)?;
119
120    sha256_compression(&mut state, &message);
121
122    Ok(state)
123}
124
125pub(crate) fn solve_poseidon2_permutation_opcode<F: AcirField>(
126    backend: &impl BlackBoxFunctionSolver<F>,
127    initial_witness: &mut WitnessMap<F>,
128    inputs: &[FunctionInput<F>],
129    outputs: &[Witness],
130) -> Result<(), OpcodeResolutionError<F>> {
131    if inputs.len() != outputs.len() {
132        return Err(OpcodeResolutionError::BlackBoxFunctionFailed(
133            acir::BlackBoxFunc::Poseidon2Permutation,
134            format!(
135                "the input and output sizes are not consistent. {} != {}",
136                inputs.len(),
137                outputs.len()
138            ),
139        ));
140    }
141
142    let state = execute_poseidon2_permutation_opcode(backend, initial_witness, inputs)?;
143
144    // Write witness assignments
145    for (output_witness, value) in outputs.iter().zip(state.into_iter()) {
146        insert_value(output_witness, value, initial_witness)?;
147    }
148    Ok(())
149}
150
151pub(crate) fn execute_poseidon2_permutation_opcode<F: AcirField>(
152    backend: &impl BlackBoxFunctionSolver<F>,
153    initial_witness: &WitnessMap<F>,
154    inputs: &[FunctionInput<F>],
155) -> Result<Vec<F>, OpcodeResolutionError<F>> {
156    // Read witness assignments
157    let state: Vec<F> = inputs
158        .iter()
159        .map(|input| input_to_value(initial_witness, *input))
160        .collect::<Result<_, _>>()?;
161
162    let state = backend.poseidon2_permutation(&state)?;
163    Ok(state)
164}
165
#[cfg(test)]
mod tests {
    use crate::pwg::blackbox::solve_generic_256_hash_opcode;
    use acir::{
        FieldElement,
        circuit::opcodes::FunctionInput,
        native_types::{Witness, WitnessMap},
    };
    use acvm_blackbox_solver::{BlackBoxResolutionError, blake2s, blake3};
    use std::collections::BTreeMap;

    /// Signature of the 32-byte hash functions accepted by `solve_generic_256_hash_opcode`.
    type HashFunction = fn(&[u8]) -> Result<[u8; 32], BlackBoxResolutionError>;

    /// Hashes the message "abc" with `hash_function` and checks the digest.
    ///
    /// The message is stored one character per witness in witnesses 1..=3. The digest is
    /// written to witnesses 4..36, and each byte is compared against `expected_digest`.
    fn assert_abc_digest(hash_function: HashFunction, expected_digest: [u8; 32]) {
        let inputs: Vec<FunctionInput<FieldElement>> =
            (1..=3).map(|i| FunctionInput::Witness(Witness(i))).collect();
        let outputs: [Witness; 32] = std::array::from_fn(|i| Witness(4 + i as u32));

        let mut initial_witness = WitnessMap::from(BTreeMap::from_iter([
            (Witness(1), FieldElement::from('a' as u128)),
            (Witness(2), FieldElement::from('b' as u128)),
            (Witness(3), FieldElement::from('c' as u128)),
        ]));

        solve_generic_256_hash_opcode(
            &mut initial_witness,
            &inputs,
            None,
            &outputs,
            hash_function,
        )
        .unwrap();

        for (output, expected_byte) in outputs.iter().zip(expected_digest) {
            assert_eq!(initial_witness[output], FieldElement::from(u128::from(expected_byte)));
        }
    }

    #[test]
    fn test_blake2s() {
        // Test vector is coming from Barretenberg (cf. blake2s.test.cpp)
        assert_abc_digest(
            blake2s,
            [
                0x50, 0x8C, 0x5E, 0x8C, 0x32, 0x7C, 0x14, 0xE2, 0xE1, 0xA7, 0x2B, 0xA3, 0x4E, 0xEB,
                0x45, 0x2F, 0x37, 0x45, 0x8B, 0x20, 0x9E, 0xD6, 0x3A, 0x29, 0x4D, 0x99, 0x9B, 0x4C,
                0x86, 0x67, 0x59, 0x82,
            ],
        );
    }

    #[test]
    fn test_blake3s() {
        // Test vector is coming from Barretenberg (cf. blake3s.test.cpp)
        assert_abc_digest(
            blake3,
            [
                0x64, 0x37, 0xB3, 0xAC, 0x38, 0x46, 0x51, 0x33, 0xFF, 0xB6, 0x3B, 0x75, 0x27, 0x3A,
                0x8D, 0xB5, 0x48, 0xC5, 0x58, 0x46, 0x5D, 0x79, 0xDB, 0x03, 0xFD, 0x35, 0x9C, 0x6C,
                0xD5, 0xBD, 0x9D, 0x85,
            ],
        );
    }
}