1use acir::{
2 AcirField,
3 circuit::opcodes::FunctionInput,
4 native_types::{Witness, WitnessMap},
5};
6use acvm_blackbox_solver::{BlackBoxFunctionSolver, BlackBoxResolutionError, sha256_compression};
7
8use crate::OpcodeResolutionError;
9use crate::pwg::{input_to_value, insert_value};
10
11pub(super) fn solve_generic_256_hash_opcode<F: AcirField>(
14 initial_witness: &mut WitnessMap<F>,
15 inputs: &[FunctionInput<F>],
16 var_message_size: Option<&FunctionInput<F>>,
17 outputs: &[Witness; 32],
18 hash_function: fn(data: &[u8]) -> Result<[u8; 32], BlackBoxResolutionError>,
19) -> Result<(), OpcodeResolutionError<F>> {
20 let message_input = get_hash_input(initial_witness, inputs, var_message_size, 8)?;
21 let digest: [u8; 32] = hash_function(&message_input)?;
22
23 write_digest_to_outputs(initial_witness, outputs, digest)
24}
25
26pub(crate) fn get_hash_input<F: AcirField>(
28 initial_witness: &WitnessMap<F>,
29 inputs: &[FunctionInput<F>],
30 message_size: Option<&FunctionInput<F>>,
31 num_bits: usize,
32) -> Result<Vec<u8>, OpcodeResolutionError<F>> {
33 let mut message_input = Vec::new();
35 for input in inputs {
36 let witness_assignment = input_to_value(initial_witness, *input)?;
37 let bytes = witness_assignment.fetch_nearest_bytes(num_bits);
38 message_input.extend(bytes);
39 }
40
41 match message_size {
43 Some(input) => {
44 let num_bytes_to_take = input_to_value(initial_witness, *input)?
45 .try_into_u128()
46 .map(|num_bytes_to_take| num_bytes_to_take as usize)
47 .expect("expected a 'num_bytes_to_take' that fit into a u128");
48
49 if num_bytes_to_take > message_input.len() {
52 return Err(OpcodeResolutionError::BlackBoxFunctionFailed(
53 acir::BlackBoxFunc::Blake2s,
54 format!(
55 "the number of bytes to take from the message is more than the number of bytes in the message. {} > {}",
56 num_bytes_to_take,
57 message_input.len()
58 ),
59 ));
60 }
61 let truncated_message = message_input[0..num_bytes_to_take].to_vec();
62 Ok(truncated_message)
63 }
64 None => Ok(message_input),
65 }
66}
67
68fn write_digest_to_outputs<F: AcirField>(
70 initial_witness: &mut WitnessMap<F>,
71 outputs: &[Witness; 32],
72 digest: [u8; 32],
73) -> Result<(), OpcodeResolutionError<F>> {
74 for (output_witness, value) in outputs.iter().zip(digest.into_iter()) {
75 insert_value(output_witness, F::from_be_bytes_reduce(&[value]), initial_witness)?;
76 }
77
78 Ok(())
79}
80
81fn to_u32_array<const N: usize, F: AcirField>(
82 initial_witness: &WitnessMap<F>,
83 inputs: &[FunctionInput<F>; N],
84) -> Result<[u32; N], OpcodeResolutionError<F>> {
85 let mut result = [0; N];
86 for (it, input) in result.iter_mut().zip(inputs) {
87 let witness_value = input_to_value(initial_witness, *input)?;
88 *it = witness_value
89 .try_into_u128()
90 .expect("expected the 'witness_value' to fit into a u128")
91 .try_into()
92 .expect("expected the 'witness_value' to fit into a u32");
93 }
94 Ok(result)
95}
96
97pub(crate) fn solve_sha_256_permutation_opcode<F: AcirField>(
98 initial_witness: &mut WitnessMap<F>,
99 inputs: &[FunctionInput<F>; 16],
100 hash_values: &[FunctionInput<F>; 8],
101 outputs: &[Witness; 8],
102) -> Result<(), OpcodeResolutionError<F>> {
103 let state = execute_sha_256_permutation_opcode(initial_witness, inputs, hash_values)?;
104
105 for (output_witness, value) in outputs.iter().zip(state.into_iter()) {
106 insert_value(output_witness, F::from(u128::from(value)), initial_witness)?;
107 }
108
109 Ok(())
110}
111
112pub(crate) fn execute_sha_256_permutation_opcode<F: AcirField>(
113 initial_witness: &WitnessMap<F>,
114 inputs: &[FunctionInput<F>; 16],
115 hash_values: &[FunctionInput<F>; 8],
116) -> Result<[u32; 8], OpcodeResolutionError<F>> {
117 let message = to_u32_array(initial_witness, inputs)?;
118 let mut state = to_u32_array(initial_witness, hash_values)?;
119
120 sha256_compression(&mut state, &message);
121
122 Ok(state)
123}
124
125pub(crate) fn solve_poseidon2_permutation_opcode<F: AcirField>(
126 backend: &impl BlackBoxFunctionSolver<F>,
127 initial_witness: &mut WitnessMap<F>,
128 inputs: &[FunctionInput<F>],
129 outputs: &[Witness],
130) -> Result<(), OpcodeResolutionError<F>> {
131 if inputs.len() != outputs.len() {
132 return Err(OpcodeResolutionError::BlackBoxFunctionFailed(
133 acir::BlackBoxFunc::Poseidon2Permutation,
134 format!(
135 "the input and output sizes are not consistent. {} != {}",
136 inputs.len(),
137 outputs.len()
138 ),
139 ));
140 }
141
142 let state = execute_poseidon2_permutation_opcode(backend, initial_witness, inputs)?;
143
144 for (output_witness, value) in outputs.iter().zip(state.into_iter()) {
146 insert_value(output_witness, value, initial_witness)?;
147 }
148 Ok(())
149}
150
151pub(crate) fn execute_poseidon2_permutation_opcode<F: AcirField>(
152 backend: &impl BlackBoxFunctionSolver<F>,
153 initial_witness: &WitnessMap<F>,
154 inputs: &[FunctionInput<F>],
155) -> Result<Vec<F>, OpcodeResolutionError<F>> {
156 let state: Vec<F> = inputs
158 .iter()
159 .map(|input| input_to_value(initial_witness, *input))
160 .collect::<Result<_, _>>()?;
161
162 let state = backend.poseidon2_permutation(&state)?;
163 Ok(state)
164}
165
#[cfg(test)]
mod tests {
    use crate::pwg::blackbox::solve_generic_256_hash_opcode;
    use acir::{
        FieldElement,
        circuit::opcodes::FunctionInput,
        native_types::{Witness, WitnessMap},
    };
    use acvm_blackbox_solver::{BlackBoxResolutionError, blake2s, blake3};
    use std::collections::BTreeMap;

    /// Signature of the 256-bit hash functions accepted by `solve_generic_256_hash_opcode`.
    type HashFn = fn(&[u8]) -> Result<[u8; 32], BlackBoxResolutionError>;

    /// Hashes the three-byte message "abc" (held in witnesses 1..=3) with `hash_function`.
    ///
    /// Asserts that witnesses 4..=35 receive `expected_digest`, one byte per witness.
    fn assert_abc_digest(hash_function: HashFn, expected_digest: [u8; 32]) {
        let inputs: Vec<FunctionInput<FieldElement>> =
            (1..=3).map(|index| FunctionInput::Witness(Witness(index))).collect();
        let outputs: [Witness; 32] = std::array::from_fn(|index| Witness(4 + index as u32));

        let mut initial_witness = WitnessMap::from(BTreeMap::from_iter([
            (Witness(1), FieldElement::from(u128::from(b'a'))),
            (Witness(2), FieldElement::from(u128::from(b'b'))),
            (Witness(3), FieldElement::from(u128::from(b'c'))),
        ]));

        solve_generic_256_hash_opcode(&mut initial_witness, &inputs, None, &outputs, hash_function)
            .unwrap();

        for (output_witness, expected_byte) in outputs.iter().zip(expected_digest) {
            assert_eq!(
                initial_witness[output_witness],
                FieldElement::from(u128::from(expected_byte))
            );
        }
    }

    #[test]
    fn test_blake2s() {
        assert_abc_digest(
            blake2s,
            [
                0x50, 0x8C, 0x5E, 0x8C, 0x32, 0x7C, 0x14, 0xE2, 0xE1, 0xA7, 0x2B, 0xA3, 0x4E, 0xEB,
                0x45, 0x2F, 0x37, 0x45, 0x8B, 0x20, 0x9E, 0xD6, 0x3A, 0x29, 0x4D, 0x99, 0x9B, 0x4C,
                0x86, 0x67, 0x59, 0x82,
            ],
        );
    }

    #[test]
    fn test_blake3() {
        assert_abc_digest(
            blake3,
            [
                0x64, 0x37, 0xB3, 0xAC, 0x38, 0x46, 0x51, 0x33, 0xFF, 0xB6, 0x3B, 0x75, 0x27, 0x3A,
                0x8D, 0xB5, 0x48, 0xC5, 0x58, 0x46, 0x5D, 0x79, 0xDB, 0x03, 0xFD, 0x35, 0x9C, 0x6C,
                0xD5, 0xBD, 0x9D, 0x85,
            ],
        );
    }
}