1use std::collections::HashMap;
2
3use acir::{
4 AcirField,
5 brillig::{ForeignCallParam, ForeignCallResult, Opcode as BrilligOpcode},
6 circuit::{
7 OpcodeLocation,
8 brillig::{BrilligFunctionId, BrilligInputs, BrilligOutputs},
9 opcodes::BlockId,
10 },
11 native_types::WitnessMap,
12};
13use acvm_blackbox_solver::BlackBoxFunctionSolver;
14use brillig_vm::{
15 BranchToFeatureMap, BrilligProfilingSamples, FailureReason, MemoryValue, VM, VMStatus,
16};
17use serde::{Deserialize, Serialize};
18
19use crate::{OpcodeResolutionError, pwg::OpcodeNotSolvable};
20
21use super::{
22 ErrorSelector, RawAssertionPayload, ResolvedAssertionPayload, get_value, insert_value,
23 memory_op::MemoryOpSolver,
24};
25
/// Result of advancing a [`BrilligSolver`], derived from the status reported by its VM.
#[derive(Debug)]
pub enum BrilligSolverStatus<F> {
    /// The VM reported `VMStatus::Finished`; results can be collected with `finalize`.
    Finished,
    /// The VM reported `VMStatus::InProgress`.
    InProgress,
    /// The VM is blocked on a foreign call; it must be answered via
    /// `BrilligSolver::resolve_pending_foreign_call` before execution can continue.
    ForeignCallWait(ForeignCallWaitInfo<F>),
}
32
/// Executes one Brillig function call on a [`VM`] on behalf of an ACIR opcode.
pub struct BrilligSolver<'b, F, B: BlackBoxFunctionSolver<F>> {
    /// The Brillig VM running the call's bytecode.
    vm: VM<'b, F, B>,
    /// Index of the ACIR opcode that issued this call; used to build
    /// `OpcodeLocation::Brillig` entries when the VM fails.
    acir_index: usize,
    /// Identifier of the Brillig function being executed; reported in
    /// `BrilligFunctionFailed` errors.
    pub function_id: BrilligFunctionId,
}
43
44impl<'b, B: BlackBoxFunctionSolver<F>, F: AcirField> BrilligSolver<'b, F, B> {
45 pub(super) fn zero_out_brillig_outputs(
47 initial_witness: &mut WitnessMap<F>,
48 outputs: &[BrilligOutputs],
49 ) -> Result<(), OpcodeResolutionError<F>> {
50 for output in outputs {
51 match output {
52 BrilligOutputs::Simple(witness) => {
53 insert_value(witness, F::zero(), initial_witness)?;
54 }
55 BrilligOutputs::Array(witness_arr) => {
56 for witness in witness_arr {
57 insert_value(witness, F::zero(), initial_witness)?;
58 }
59 }
60 }
61 }
62 Ok(())
63 }
64
65 #[allow(clippy::too_many_arguments)]
68 pub(crate) fn new_call(
69 initial_witness: &WitnessMap<F>,
70 memory: &HashMap<BlockId, MemoryOpSolver<F>>,
71 inputs: &'b [BrilligInputs<F>],
72 brillig_bytecode: &'b [BrilligOpcode<F>],
73 bb_solver: &'b B,
74 acir_index: usize,
75 brillig_function_id: BrilligFunctionId,
76 profiling_active: bool,
77 with_branch_to_feature_map: Option<&BranchToFeatureMap>,
78 ) -> Result<Self, OpcodeResolutionError<F>> {
79 let vm = Self::setup_brillig_vm(
80 initial_witness,
81 memory,
82 inputs,
83 brillig_bytecode,
84 bb_solver,
85 profiling_active,
86 with_branch_to_feature_map,
87 )?;
88 Ok(Self { vm, acir_index, function_id: brillig_function_id })
89 }
90
91 fn setup_brillig_vm(
95 initial_witness: &WitnessMap<F>,
96 memory: &HashMap<BlockId, MemoryOpSolver<F>>,
97 inputs: &[BrilligInputs<F>],
98 brillig_bytecode: &'b [BrilligOpcode<F>],
99 bb_solver: &'b B,
100 profiling_active: bool,
101 with_branch_to_feature_map: Option<&BranchToFeatureMap>,
102 ) -> Result<VM<'b, F, B>, OpcodeResolutionError<F>> {
103 let mut calldata: Vec<F> = Vec::new();
105 for input in inputs {
110 match input {
111 BrilligInputs::Single(expr) => match get_value(expr, initial_witness) {
112 Ok(value) => calldata.push(value),
113 Err(_) => {
114 return Err(OpcodeResolutionError::OpcodeNotSolvable(
115 OpcodeNotSolvable::ExpressionHasTooManyUnknowns(expr.clone()),
116 ));
117 }
118 },
119 BrilligInputs::Array(expr_arr) => {
120 for expr in expr_arr {
122 match get_value(expr, initial_witness) {
123 Ok(value) => calldata.push(value),
124 Err(_) => {
125 return Err(OpcodeResolutionError::OpcodeNotSolvable(
126 OpcodeNotSolvable::ExpressionHasTooManyUnknowns(expr.clone()),
127 ));
128 }
129 }
130 }
131 }
132 BrilligInputs::MemoryArray(block_id) => {
133 let memory_block = memory
134 .get(block_id)
135 .ok_or(OpcodeNotSolvable::MissingMemoryBlock(block_id.0))?;
136 calldata.extend(&memory_block.block_value);
137 }
138 }
139 }
140
141 let vm = VM::new(
144 calldata,
145 brillig_bytecode,
146 bb_solver,
147 profiling_active,
148 with_branch_to_feature_map,
149 );
150 Ok(vm)
151 }
152
153 pub fn get_memory(&self) -> &[MemoryValue<F>] {
154 self.vm.get_memory()
155 }
156
157 pub fn write_memory_at(&mut self, ptr: u32, value: MemoryValue<F>) {
158 self.vm.write_memory_at(ptr, value);
159 }
160
161 pub fn get_call_stack(&self) -> Vec<usize> {
162 self.vm.get_call_stack()
163 }
164
165 pub fn get_fuzzing_trace(&self) -> Vec<u32> {
166 self.vm.get_fuzzing_trace()
167 }
168
169 pub(crate) fn solve(&mut self) -> Result<BrilligSolverStatus<F>, OpcodeResolutionError<F>> {
170 let status = self.vm.process_opcodes();
171 self.handle_vm_status(status)
172 }
173
174 pub fn step(&mut self) -> Result<BrilligSolverStatus<F>, OpcodeResolutionError<F>> {
175 let status = self.vm.process_opcode().clone();
176 self.handle_vm_status(status)
177 }
178
179 pub fn program_counter(&self) -> usize {
180 self.vm.program_counter()
181 }
182
183 fn handle_vm_status(
188 &self,
189 vm_status: VMStatus<F>,
190 ) -> Result<BrilligSolverStatus<F>, OpcodeResolutionError<F>> {
191 match vm_status {
192 VMStatus::Finished { .. } => Ok(BrilligSolverStatus::Finished),
193 VMStatus::InProgress => Ok(BrilligSolverStatus::InProgress),
194 VMStatus::Failure { reason, call_stack } => {
195 let call_stack = call_stack
196 .iter()
197 .map(|brillig_index| OpcodeLocation::Brillig {
198 acir_index: self.acir_index,
199 brillig_index: *brillig_index,
200 })
201 .collect();
202 let payload = match reason {
203 FailureReason::RuntimeError { message } => {
204 Some(ResolvedAssertionPayload::String(message))
205 }
206 FailureReason::Trap { revert_data_offset, revert_data_size } => {
207 extract_failure_payload_from_memory(
208 self.vm.get_memory(),
209 revert_data_offset
210 .try_into()
211 .expect("Failed conversion from u32 to usize"),
212 revert_data_size
213 .try_into()
214 .expect("Failed conversion from u32 to usize"),
215 )
216 }
217 };
218
219 Err(OpcodeResolutionError::BrilligFunctionFailed {
220 function_id: self.function_id,
221 payload,
222 call_stack,
223 })
224 }
225 VMStatus::ForeignCallWait { function, inputs } => {
226 Ok(BrilligSolverStatus::ForeignCallWait(ForeignCallWaitInfo { function, inputs }))
227 }
228 }
229 }
230
231 pub(crate) fn finalize(
232 self,
233 witness: &mut WitnessMap<F>,
234 outputs: &[BrilligOutputs],
235 ) -> Result<(), OpcodeResolutionError<F>> {
236 assert!(!self.vm.is_profiling_active(), "Expected VM profiling to not be active");
237 self.finalize_inner(witness, outputs)
238 }
239
240 pub(crate) fn finalize_with_profiling(
242 mut self,
243 witness: &mut WitnessMap<F>,
244 outputs: &[BrilligOutputs],
245 ) -> Result<BrilligProfilingSamples, OpcodeResolutionError<F>> {
246 assert!(self.vm.is_profiling_active(), "Expected VM profiling to be active");
247 self.finalize_inner(witness, outputs)?;
248 Ok(self.vm.take_profiling_samples())
249 }
250
251 fn finalize_inner(
253 &self,
254 witness: &mut WitnessMap<F>,
255 outputs: &[BrilligOutputs],
256 ) -> Result<(), OpcodeResolutionError<F>> {
257 let vm_status = self.vm.get_status();
259 match vm_status {
260 VMStatus::Finished { return_data_offset, return_data_size } => {
261 self.write_brillig_outputs(
262 witness,
263 return_data_offset.try_into().expect("Failed conversion from u32 to usize"),
264 return_data_size.try_into().expect("Failed conversion from u32 to usize"),
265 outputs,
266 )?;
267 Ok(())
268 }
269 _ => panic!("Brillig VM has not completed execution"),
270 }
271 }
272
273 fn write_brillig_outputs(
275 &self,
276 witness_map: &mut WitnessMap<F>,
277 return_data_offset: usize,
278 return_data_size: usize,
279 outputs: &[BrilligOutputs],
280 ) -> Result<(), OpcodeResolutionError<F>> {
281 let memory = self.vm.get_memory();
282 let mut current_ret_data_idx = return_data_offset;
283 for output in outputs {
284 match output {
285 BrilligOutputs::Simple(witness) => {
286 let value = memory
287 .get(current_ret_data_idx)
288 .expect("Return data index exceeds memory bounds");
289 insert_value(witness, value.to_field(), witness_map)?;
290 current_ret_data_idx =
291 current_ret_data_idx.checked_add(1).expect("Return data index overflow");
292 }
293 BrilligOutputs::Array(witness_arr) => {
294 for witness in witness_arr {
295 let value = memory
296 .get(current_ret_data_idx)
297 .expect("Return data index exceeds memory bounds");
298 insert_value(witness, value.to_field(), witness_map)?;
299 current_ret_data_idx = current_ret_data_idx
300 .checked_add(1)
301 .expect("Return data index overflow");
302 }
303 }
304 }
305 }
306
307 let expected_end = return_data_offset
308 .checked_add(return_data_size)
309 .expect("Return data offset and size overflow");
310 assert!(
311 current_ret_data_idx == expected_end,
312 "Brillig VM did not write the expected number of return values"
313 );
314 Ok(())
315 }
316
317 pub fn resolve_pending_foreign_call(&mut self, foreign_call_result: ForeignCallResult<F>) {
318 match self.vm.get_status() {
319 VMStatus::ForeignCallWait { .. } => self.vm.resolve_foreign_call(foreign_call_result),
320 _ => unreachable!("Brillig VM is not waiting for a foreign call"),
321 }
322 }
323}
324
325fn extract_failure_payload_from_memory<F: AcirField>(
329 memory: &[MemoryValue<F>],
330 revert_data_offset: usize,
331 revert_data_size: usize,
332) -> Option<ResolvedAssertionPayload<F>> {
333 if revert_data_size == 0 {
335 None
336 } else {
337 let end = revert_data_offset
338 .checked_add(revert_data_size)
339 .expect("Revert data offset and size overflow");
340 let mut revert_values_iter = memory
341 .get(revert_data_offset..end)
342 .expect("Revert data offset and size exceed memory bounds")
343 .iter();
344 let error_selector = ErrorSelector::new(
345 revert_values_iter
346 .next()
347 .copied()
348 .expect("Incorrect revert data size")
349 .try_into()
350 .expect("Error selector is not u64"),
351 );
352
353 Some(ResolvedAssertionPayload::Raw(RawAssertionPayload {
354 selector: error_selector,
355 data: revert_values_iter.map(|value| value.to_field()).collect(),
356 }))
357 }
358}
359
/// Details of a foreign call the Brillig VM is blocked on, taken from
/// `VMStatus::ForeignCallWait`.
// NOTE: field order is part of the derived Serialize/Deserialize format; do not reorder.
#[derive(Debug, PartialEq, Clone, Serialize, Deserialize)]
pub struct ForeignCallWaitInfo<F> {
    /// Identifier of the foreign function the VM requested.
    pub function: String,
    /// Parameters the VM supplied for the foreign call.
    pub inputs: Vec<ForeignCallParam<F>>,
}
371
#[cfg(test)]
mod tests {
    use crate::pwg::BrilligSolver;
    use acir::{
        FieldElement,
        brillig::{BinaryFieldOp, BitSize, HeapVector, IntegerBitSize, MemoryAddress, Opcode},
        circuit::brillig::{BrilligFunctionId, BrilligInputs, BrilligOutputs},
        native_types::{Expression, Witness, WitnessMap},
    };
    use std::collections::{BTreeMap, HashMap};

    /// Builds a program with the following steps:
    /// 1. Copy the first calldata value into slot 0.
    /// 2. If that value is non-zero, add one to it.
    /// 3. Return slot 0 as a single value.
    fn increment_unless_zero_bytecode() -> Vec<Opcode<FieldElement>> {
        vec![
            Opcode::Const {
                destination: MemoryAddress::Direct(21),
                bit_size: BitSize::Integer(IntegerBitSize::U32),
                value: FieldElement::from(1_u128),
            },
            Opcode::Const {
                destination: MemoryAddress::Direct(20),
                bit_size: BitSize::Integer(IntegerBitSize::U32),
                value: FieldElement::from(0_u128),
            },
            Opcode::CalldataCopy {
                destination_address: MemoryAddress::Direct(0),
                size_address: MemoryAddress::Direct(21),
                offset_address: MemoryAddress::Direct(20),
            },
            Opcode::Const {
                destination: MemoryAddress::Direct(2),
                bit_size: BitSize::Field,
                value: FieldElement::from(0_u128),
            },
            Opcode::BinaryFieldOp {
                destination: MemoryAddress::Direct(3),
                op: BinaryFieldOp::Equals,
                lhs: MemoryAddress::Direct(0),
                rhs: MemoryAddress::Direct(2),
            },
            Opcode::JumpIf { condition: MemoryAddress::Direct(3), location: 8 },
            Opcode::Const {
                destination: MemoryAddress::Direct(1),
                bit_size: BitSize::Field,
                value: FieldElement::from(1_u128),
            },
            Opcode::BinaryFieldOp {
                destination: MemoryAddress::Direct(0),
                op: BinaryFieldOp::Add,
                lhs: MemoryAddress::Direct(1),
                rhs: MemoryAddress::Direct(0),
            },
            Opcode::Stop {
                return_data: HeapVector {
                    pointer: MemoryAddress::Direct(20),
                    size: MemoryAddress::Direct(21),
                },
            },
        ]
    }

    #[test]
    fn test_solver() {
        // Witnesses 1..=3 feed the calldata; witness 4 receives the single return value.
        let assignments = [(1, 1u128), (2, 1u128), (3, 2u128)];
        let mut witness_map = WitnessMap::from(
            assignments
                .iter()
                .map(|&(index, value)| (Witness(index), FieldElement::from(value)))
                .collect::<BTreeMap<_, _>>(),
        );
        let inputs: Vec<BrilligInputs<FieldElement>> = assignments
            .iter()
            .map(|&(index, _)| BrilligInputs::Single(Expression::from(Witness(index))))
            .collect();

        let bytecode = increment_unless_zero_bytecode();
        let blackbox_solver = acvm_blackbox_solver::StubbedBlackBoxSolver;
        let no_memory_blocks = HashMap::new();

        let mut solver = BrilligSolver::new_call(
            &witness_map,
            &no_memory_blocks,
            &inputs,
            &bytecode,
            &blackbox_solver,
            0,
            BrilligFunctionId::default(),
            false,
            None,
        )
        .unwrap();
        solver.solve().unwrap();

        let outputs = [BrilligOutputs::Simple(Witness(4))];
        solver.finalize(&mut witness_map, &outputs).unwrap();

        // calldata[0] is 1 (non-zero), so the program returns 1 + 1.
        assert_eq!(witness_map[&Witness(4)], FieldElement::from(2u128));
    }
}