acvm/pwg/
brillig.rs

1use std::collections::HashMap;
2
3use acir::{
4    AcirField,
5    brillig::{ForeignCallParam, ForeignCallResult, Opcode as BrilligOpcode},
6    circuit::{
7        OpcodeLocation,
8        brillig::{BrilligFunctionId, BrilligInputs, BrilligOutputs},
9        opcodes::BlockId,
10    },
11    native_types::WitnessMap,
12};
13use acvm_blackbox_solver::BlackBoxFunctionSolver;
14use brillig_vm::{
15    BranchToFeatureMap, BrilligProfilingSamples, FailureReason, MemoryValue, VM, VMStatus,
16};
17use serde::{Deserialize, Serialize};
18
19use crate::{OpcodeResolutionError, pwg::OpcodeNotSolvable};
20
21use super::{
22    ErrorSelector, RawAssertionPayload, ResolvedAssertionPayload, get_value, insert_value,
23    memory_op::MemoryOpSolver,
24};
25
26#[derive(Debug)]
27pub enum BrilligSolverStatus<F> {
28    Finished,
29    InProgress,
30    ForeignCallWait(ForeignCallWaitInfo<F>),
31}
32
33/// Specific solver for Brillig opcodes
34/// It maintains a Brillig VM that can execute the bytecode of the called brillig function
35pub struct BrilligSolver<'b, F, B: BlackBoxFunctionSolver<F>> {
36    vm: VM<'b, F, B>,
37    acir_index: usize,
38    /// This id references which Brillig function within the main ACIR program we are solving.
39    /// This is used for appropriately resolving errors as the ACIR program artifacts
40    /// set up their Brillig debug metadata by function id.
41    pub function_id: BrilligFunctionId,
42}
43
44impl<'b, B: BlackBoxFunctionSolver<F>, F: AcirField> BrilligSolver<'b, F, B> {
45    /// Assigns the zero value to all outputs of a given [brillig call][acir::circuit::opcodes::Opcode::BrilligCall].
46    pub(super) fn zero_out_brillig_outputs(
47        initial_witness: &mut WitnessMap<F>,
48        outputs: &[BrilligOutputs],
49    ) -> Result<(), OpcodeResolutionError<F>> {
50        for output in outputs {
51            match output {
52                BrilligOutputs::Simple(witness) => {
53                    insert_value(witness, F::zero(), initial_witness)?;
54                }
55                BrilligOutputs::Array(witness_arr) => {
56                    for witness in witness_arr {
57                        insert_value(witness, F::zero(), initial_witness)?;
58                    }
59                }
60            }
61        }
62        Ok(())
63    }
64
65    /// Constructs a solver for a Brillig block given the bytecode and initial
66    /// witness.
67    #[allow(clippy::too_many_arguments)]
68    pub(crate) fn new_call(
69        initial_witness: &WitnessMap<F>,
70        memory: &HashMap<BlockId, MemoryOpSolver<F>>,
71        inputs: &'b [BrilligInputs<F>],
72        brillig_bytecode: &'b [BrilligOpcode<F>],
73        bb_solver: &'b B,
74        acir_index: usize,
75        brillig_function_id: BrilligFunctionId,
76        profiling_active: bool,
77        with_branch_to_feature_map: Option<&BranchToFeatureMap>,
78    ) -> Result<Self, OpcodeResolutionError<F>> {
79        let vm = Self::setup_brillig_vm(
80            initial_witness,
81            memory,
82            inputs,
83            brillig_bytecode,
84            bb_solver,
85            profiling_active,
86            with_branch_to_feature_map,
87        )?;
88        Ok(Self { vm, acir_index, function_id: brillig_function_id })
89    }
90
91    /// Get a BrilligVM for executing the provided bytecode
92    /// 1. Reduce the input expressions into a known value, or error if they do not reduce to a value.
93    /// 2. Instantiate the Brillig VM with the bytecode and the reduced inputs.
94    fn setup_brillig_vm(
95        initial_witness: &WitnessMap<F>,
96        memory: &HashMap<BlockId, MemoryOpSolver<F>>,
97        inputs: &[BrilligInputs<F>],
98        brillig_bytecode: &'b [BrilligOpcode<F>],
99        bb_solver: &'b B,
100        profiling_active: bool,
101        with_branch_to_feature_map: Option<&BranchToFeatureMap>,
102    ) -> Result<VM<'b, F, B>, OpcodeResolutionError<F>> {
103        // Set input values
104        let mut calldata: Vec<F> = Vec::new();
105        // Each input represents an expression or array of expressions to evaluate.
106        // Iterate over each input and evaluate the expression(s) associated with it.
107        // Push the results into memory.
108        // If a certain expression is not solvable, we stall the ACVM and do not proceed with Brillig VM execution.
109        for input in inputs {
110            match input {
111                BrilligInputs::Single(expr) => match get_value(expr, initial_witness) {
112                    Ok(value) => calldata.push(value),
113                    Err(_) => {
114                        return Err(OpcodeResolutionError::OpcodeNotSolvable(
115                            OpcodeNotSolvable::ExpressionHasTooManyUnknowns(expr.clone()),
116                        ));
117                    }
118                },
119                BrilligInputs::Array(expr_arr) => {
120                    // Attempt to fetch all array input values
121                    for expr in expr_arr {
122                        match get_value(expr, initial_witness) {
123                            Ok(value) => calldata.push(value),
124                            Err(_) => {
125                                return Err(OpcodeResolutionError::OpcodeNotSolvable(
126                                    OpcodeNotSolvable::ExpressionHasTooManyUnknowns(expr.clone()),
127                                ));
128                            }
129                        }
130                    }
131                }
132                BrilligInputs::MemoryArray(block_id) => {
133                    let memory_block = memory
134                        .get(block_id)
135                        .ok_or(OpcodeNotSolvable::MissingMemoryBlock(block_id.0))?;
136                    calldata.extend(&memory_block.block_value);
137                }
138            }
139        }
140
141        // Instantiate a Brillig VM given the solved calldata
142        // along with the Brillig bytecode.
143        let vm = VM::new(
144            calldata,
145            brillig_bytecode,
146            bb_solver,
147            profiling_active,
148            with_branch_to_feature_map,
149        );
150        Ok(vm)
151    }
152
153    pub fn get_memory(&self) -> &[MemoryValue<F>] {
154        self.vm.get_memory()
155    }
156
157    pub fn write_memory_at(&mut self, ptr: u32, value: MemoryValue<F>) {
158        self.vm.write_memory_at(ptr, value);
159    }
160
161    pub fn get_call_stack(&self) -> Vec<usize> {
162        self.vm.get_call_stack()
163    }
164
165    pub fn get_fuzzing_trace(&self) -> Vec<u32> {
166        self.vm.get_fuzzing_trace()
167    }
168
169    pub(crate) fn solve(&mut self) -> Result<BrilligSolverStatus<F>, OpcodeResolutionError<F>> {
170        let status = self.vm.process_opcodes();
171        self.handle_vm_status(status)
172    }
173
174    pub fn step(&mut self) -> Result<BrilligSolverStatus<F>, OpcodeResolutionError<F>> {
175        let status = self.vm.process_opcode().clone();
176        self.handle_vm_status(status)
177    }
178
179    pub fn program_counter(&self) -> usize {
180        self.vm.program_counter()
181    }
182
183    /// Returns the status of the Brillig VM as a 'BrilligSolverStatus' resolution.
184    /// It may be finished, in-progress, failed, or may be waiting for results of a foreign call.
185    /// Return the "resolution" to the caller who may choose to make subsequent calls
186    /// (when it gets foreign call results for example).
187    fn handle_vm_status(
188        &self,
189        vm_status: VMStatus<F>,
190    ) -> Result<BrilligSolverStatus<F>, OpcodeResolutionError<F>> {
191        match vm_status {
192            VMStatus::Finished { .. } => Ok(BrilligSolverStatus::Finished),
193            VMStatus::InProgress => Ok(BrilligSolverStatus::InProgress),
194            VMStatus::Failure { reason, call_stack } => {
195                let call_stack = call_stack
196                    .iter()
197                    .map(|brillig_index| OpcodeLocation::Brillig {
198                        acir_index: self.acir_index,
199                        brillig_index: *brillig_index,
200                    })
201                    .collect();
202                let payload = match reason {
203                    FailureReason::RuntimeError { message } => {
204                        Some(ResolvedAssertionPayload::String(message))
205                    }
206                    FailureReason::Trap { revert_data_offset, revert_data_size } => {
207                        extract_failure_payload_from_memory(
208                            self.vm.get_memory(),
209                            revert_data_offset
210                                .try_into()
211                                .expect("Failed conversion from u32 to usize"),
212                            revert_data_size
213                                .try_into()
214                                .expect("Failed conversion from u32 to usize"),
215                        )
216                    }
217                };
218
219                Err(OpcodeResolutionError::BrilligFunctionFailed {
220                    function_id: self.function_id,
221                    payload,
222                    call_stack,
223                })
224            }
225            VMStatus::ForeignCallWait { function, inputs } => {
226                Ok(BrilligSolverStatus::ForeignCallWait(ForeignCallWaitInfo { function, inputs }))
227            }
228        }
229    }
230
231    pub(crate) fn finalize(
232        self,
233        witness: &mut WitnessMap<F>,
234        outputs: &[BrilligOutputs],
235    ) -> Result<(), OpcodeResolutionError<F>> {
236        assert!(!self.vm.is_profiling_active(), "Expected VM profiling to not be active");
237        self.finalize_inner(witness, outputs)
238    }
239
240    /// Finalize the VM and return the profiling samples.
241    pub(crate) fn finalize_with_profiling(
242        mut self,
243        witness: &mut WitnessMap<F>,
244        outputs: &[BrilligOutputs],
245    ) -> Result<BrilligProfilingSamples, OpcodeResolutionError<F>> {
246        assert!(self.vm.is_profiling_active(), "Expected VM profiling to be active");
247        self.finalize_inner(witness, outputs)?;
248        Ok(self.vm.take_profiling_samples())
249    }
250
251    /// Finalize the VM execution and write the outputs to the provided witness map.
252    fn finalize_inner(
253        &self,
254        witness: &mut WitnessMap<F>,
255        outputs: &[BrilligOutputs],
256    ) -> Result<(), OpcodeResolutionError<F>> {
257        // Finish the Brillig execution by writing the outputs to the witness map
258        let vm_status = self.vm.get_status();
259        match vm_status {
260            VMStatus::Finished { return_data_offset, return_data_size } => {
261                self.write_brillig_outputs(
262                    witness,
263                    return_data_offset.try_into().expect("Failed conversion from u32 to usize"),
264                    return_data_size.try_into().expect("Failed conversion from u32 to usize"),
265                    outputs,
266                )?;
267                Ok(())
268            }
269            _ => panic!("Brillig VM has not completed execution"),
270        }
271    }
272
273    /// Write VM execution results into the witness map
274    fn write_brillig_outputs(
275        &self,
276        witness_map: &mut WitnessMap<F>,
277        return_data_offset: usize,
278        return_data_size: usize,
279        outputs: &[BrilligOutputs],
280    ) -> Result<(), OpcodeResolutionError<F>> {
281        let memory = self.vm.get_memory();
282        let mut current_ret_data_idx = return_data_offset;
283        for output in outputs {
284            match output {
285                BrilligOutputs::Simple(witness) => {
286                    let value = memory
287                        .get(current_ret_data_idx)
288                        .expect("Return data index exceeds memory bounds");
289                    insert_value(witness, value.to_field(), witness_map)?;
290                    current_ret_data_idx =
291                        current_ret_data_idx.checked_add(1).expect("Return data index overflow");
292                }
293                BrilligOutputs::Array(witness_arr) => {
294                    for witness in witness_arr {
295                        let value = memory
296                            .get(current_ret_data_idx)
297                            .expect("Return data index exceeds memory bounds");
298                        insert_value(witness, value.to_field(), witness_map)?;
299                        current_ret_data_idx = current_ret_data_idx
300                            .checked_add(1)
301                            .expect("Return data index overflow");
302                    }
303                }
304            }
305        }
306
307        let expected_end = return_data_offset
308            .checked_add(return_data_size)
309            .expect("Return data offset and size overflow");
310        assert!(
311            current_ret_data_idx == expected_end,
312            "Brillig VM did not write the expected number of return values"
313        );
314        Ok(())
315    }
316
317    pub fn resolve_pending_foreign_call(&mut self, foreign_call_result: ForeignCallResult<F>) {
318        match self.vm.get_status() {
319            VMStatus::ForeignCallWait { .. } => self.vm.resolve_foreign_call(foreign_call_result),
320            _ => unreachable!("Brillig VM is not waiting for a foreign call"),
321        }
322    }
323}
324
325/// Extracts a `ResolvedAssertionPayload` from a block of memory of a Brillig VM instance.
326///
327/// Returns `None` if the amount of memory requested is zero.
328fn extract_failure_payload_from_memory<F: AcirField>(
329    memory: &[MemoryValue<F>],
330    revert_data_offset: usize,
331    revert_data_size: usize,
332) -> Option<ResolvedAssertionPayload<F>> {
333    // Since noir can only revert with strings currently, we can parse return data as a string
334    if revert_data_size == 0 {
335        None
336    } else {
337        let end = revert_data_offset
338            .checked_add(revert_data_size)
339            .expect("Revert data offset and size overflow");
340        let mut revert_values_iter = memory
341            .get(revert_data_offset..end)
342            .expect("Revert data offset and size exceed memory bounds")
343            .iter();
344        let error_selector = ErrorSelector::new(
345            revert_values_iter
346                .next()
347                .copied()
348                .expect("Incorrect revert data size")
349                .try_into()
350                .expect("Error selector is not u64"),
351        );
352
353        Some(ResolvedAssertionPayload::Raw(RawAssertionPayload {
354            selector: error_selector,
355            data: revert_values_iter.map(|value| value.to_field()).collect(),
356        }))
357    }
358}
359
360/// Encapsulates a request from a Brillig VM process that encounters a [foreign call opcode][brillig_vm::brillig::Opcode::ForeignCall]
361/// where the result of the foreign call has not yet been provided.
362///
363/// The caller must resolve this opcode externally based upon the information in the request.
364#[derive(Debug, PartialEq, Clone, Serialize, Deserialize)]
365pub struct ForeignCallWaitInfo<F> {
366    /// An identifier interpreted by the caller process
367    pub function: String,
368    /// Resolved inputs to a foreign call computed in the previous steps of a Brillig VM process
369    pub inputs: Vec<ForeignCallParam<F>>,
370}
371
#[cfg(test)]
mod tests {
    use crate::pwg::BrilligSolver;
    use acir::{
        FieldElement,
        brillig::{BinaryFieldOp, BitSize, HeapVector, IntegerBitSize, MemoryAddress, Opcode},
        circuit::brillig::{BrilligFunctionId, BrilligInputs, BrilligOutputs},
        native_types::{Expression, Witness, WitnessMap},
    };
    use std::collections::{BTreeMap, HashMap};

    /// Bytecode that copies the first calldata value into memory cell 0.
    ///
    /// It adds one to that cell unless the value equals zero (the `JumpIf` skips
    /// straight to `Stop`). It then returns the single cell starting at address 0;
    /// cells 20 and 21 hold the return pointer (0) and length (1).
    fn increment_first_input_bytecode() -> Vec<Opcode<FieldElement>> {
        vec![
            Opcode::Const {
                destination: MemoryAddress::Direct(21),
                bit_size: BitSize::Integer(IntegerBitSize::U32),
                value: FieldElement::from(1_u128),
            },
            Opcode::Const {
                destination: MemoryAddress::Direct(20),
                bit_size: BitSize::Integer(IntegerBitSize::U32),
                value: FieldElement::from(0_u128),
            },
            Opcode::CalldataCopy {
                destination_address: MemoryAddress::Direct(0),
                size_address: MemoryAddress::Direct(21),
                offset_address: MemoryAddress::Direct(20),
            },
            Opcode::Const {
                destination: MemoryAddress::Direct(2),
                bit_size: BitSize::Field,
                value: FieldElement::from(0_u128),
            },
            Opcode::BinaryFieldOp {
                destination: MemoryAddress::Direct(3),
                op: BinaryFieldOp::Equals,
                lhs: MemoryAddress::Direct(0),
                rhs: MemoryAddress::Direct(2),
            },
            Opcode::JumpIf { condition: MemoryAddress::Direct(3), location: 8 },
            Opcode::Const {
                destination: MemoryAddress::Direct(1),
                bit_size: BitSize::Field,
                value: FieldElement::from(1_u128),
            },
            Opcode::BinaryFieldOp {
                destination: MemoryAddress::Direct(0),
                op: BinaryFieldOp::Add,
                lhs: MemoryAddress::Direct(1),
                rhs: MemoryAddress::Direct(0),
            },
            Opcode::Stop {
                return_data: HeapVector {
                    pointer: MemoryAddress::Direct(20),
                    size: MemoryAddress::Direct(21),
                },
            },
        ]
    }

    #[test]
    fn test_solver() {
        let mut witness_map = WitnessMap::from(BTreeMap::from([
            (Witness(1), FieldElement::from(1u128)),
            (Witness(2), FieldElement::from(1u128)),
            (Witness(3), FieldElement::from(2u128)),
        ]));
        let inputs: Vec<_> = (1..=3)
            .map(|index| BrilligInputs::Single(Expression::from(Witness(index))))
            .collect();
        let bytecode = increment_first_input_bytecode();
        let blackbox_solver = acvm_blackbox_solver::StubbedBlackBoxSolver;

        let mut solver = BrilligSolver::new_call(
            &witness_map,
            &HashMap::new(),
            &inputs,
            &bytecode,
            &blackbox_solver,
            0,
            BrilligFunctionId::default(),
            false,
            None,
        )
        .unwrap();
        solver.solve().unwrap();

        let outputs = [BrilligOutputs::Simple(Witness(4))];
        solver.finalize(&mut witness_map, &outputs).unwrap();

        // The first input (1) is non-zero, so the program returns 1 + 1.
        assert_eq!(witness_map[&Witness(4)], FieldElement::from(2u128));
    }
}