acvm/compiler/
simulator.rs

1use acir::{
2    AcirField,
3    circuit::{
4        Circuit, Opcode,
5        brillig::{BrilligInputs, BrilligOutputs},
6        opcodes::{BlockId, FunctionInput},
7    },
8    native_types::{Expression, Witness},
9};
10use std::collections::HashSet;
11
12use crate::pwg::arithmetic::ExpressionSolver;
13
14/// Simulate solving a circuit symbolically
15/// Instead of evaluating witness values from the inputs, like the PWG module is doing,
16/// this pass simply marks the witness that can be evaluated, from the known inputs,
17/// and incrementally from the previously marked witnesses.
18/// This avoids any computation on a big field which makes the process efficient.
19/// When all the witness of an opcode are marked as solvable, it means that the
20/// opcode is solvable.
21#[derive(Default)]
22pub struct CircuitSimulator {
23    /// Track the witnesses that can be solved
24    solvable_witnesses: HashSet<Witness>,
25
26    /// Track whether a [`BlockId`] has been initialized
27    initialized_blocks: HashSet<BlockId>,
28}
29
30impl CircuitSimulator {
31    /// Check whether the circuit is solvable in theory.
32    ///
33    /// # Returns
34    ///
35    /// Returns `None` if the circuit is deemed to be solvable
36    /// Otherwise returns `Some(index)` where `index` is the opcode index of the first unsolvable opcode.
37    pub fn check_circuit<F: AcirField>(circuit: &Circuit<F>) -> Option<usize> {
38        Self::default().run_check_circuit(circuit)
39    }
40
41    /// Simulate solving a circuit symbolically by keeping track of the witnesses that can be solved.
42    /// Returns the index of an opcode that cannot be solved, if any.
43    #[tracing::instrument(level = "trace", skip_all)]
44    fn run_check_circuit<F: AcirField>(&mut self, circuit: &Circuit<F>) -> Option<usize> {
45        let circuit_inputs = circuit.circuit_arguments();
46        self.solvable_witnesses.extend(circuit_inputs.iter());
47        for (i, op) in circuit.opcodes.iter().enumerate() {
48            if !self.try_solve(op) {
49                return Some(i);
50            }
51        }
52        None
53    }
54
55    /// Check if the Opcode can be solved, and if yes, add the solved witness to set of solvable witness
56    fn try_solve<F: AcirField>(&mut self, opcode: &Opcode<F>) -> bool {
57        match opcode {
58            Opcode::AssertZero(expr) => {
59                let mut unresolved = HashSet::new();
60                let combined_mul_terms = ExpressionSolver::combine_mul_terms(&expr.mul_terms);
61                let combined_linear_terms =
62                    ExpressionSolver::combine_linear_terms(&expr.linear_combinations);
63                for (_, w1, w2) in &combined_mul_terms {
64                    if !self.solvable_witnesses.contains(w1) {
65                        if !self.solvable_witnesses.contains(w2) {
66                            return false;
67                        }
68                        unresolved.insert(*w1);
69                    }
70                    if !self.solvable_witnesses.contains(w2) && w1 != w2 {
71                        unresolved.insert(*w2);
72                    }
73                }
74                for (_, w) in &combined_linear_terms {
75                    if !self.solvable_witnesses.contains(w) {
76                        unresolved.insert(*w);
77                    }
78                }
79                if unresolved.len() == 1 {
80                    self.mark_solvable(*unresolved.iter().next().unwrap());
81                    return true;
82                }
83                unresolved.is_empty()
84            }
85            Opcode::BlackBoxFuncCall(black_box_func_call) => {
86                let inputs = black_box_func_call.get_inputs_vec();
87                for input in inputs {
88                    if !self.can_solve_function_input(&input) {
89                        return false;
90                    }
91                }
92                let outputs = black_box_func_call.get_outputs_vec();
93                for output in outputs {
94                    self.mark_solvable(output);
95                }
96                true
97            }
98            Opcode::MemoryOp { block_id, op } => {
99                if !self.initialized_blocks.contains(block_id) {
100                    // Memory must be initialized before it can be used.
101                    return false;
102                }
103                if !self.can_solve_expression(&op.index) {
104                    return false;
105                }
106                if op.operation.is_zero() {
107                    let Some(w) = op.value.to_witness() else {
108                        return false;
109                    };
110                    self.mark_solvable(w);
111                    true
112                } else {
113                    self.can_solve_expression(&op.value)
114                }
115            }
116            Opcode::MemoryInit { block_id, init, .. } => {
117                for w in init {
118                    if !self.solvable_witnesses.contains(w) {
119                        return false;
120                    }
121                }
122                self.initialized_blocks.insert(*block_id)
123            }
124            Opcode::BrilligCall { id: _, inputs, outputs, predicate } => {
125                for input in inputs {
126                    if !self.can_solve_brillig_input(input) {
127                        return false;
128                    }
129                }
130                if !self.can_solve_expression(predicate) {
131                    return false;
132                }
133                for output in outputs {
134                    match output {
135                        BrilligOutputs::Simple(w) => self.mark_solvable(*w),
136                        BrilligOutputs::Array(arr) => {
137                            for w in arr {
138                                self.mark_solvable(*w);
139                            }
140                        }
141                    }
142                }
143                true
144            }
145            Opcode::Call { id: _, inputs, outputs, predicate } => {
146                for w in inputs {
147                    if !self.solvable_witnesses.contains(w) {
148                        return false;
149                    }
150                }
151                if !self.can_solve_expression(predicate) {
152                    return false;
153                }
154                for w in outputs {
155                    self.mark_solvable(*w);
156                }
157                true
158            }
159        }
160    }
161
162    /// Adds the witness to set of solvable witness
163    pub(crate) fn mark_solvable(&mut self, witness: Witness) {
164        self.solvable_witnesses.insert(witness);
165    }
166
167    pub fn can_solve_function_input<F: AcirField>(&self, input: &FunctionInput<F>) -> bool {
168        if let FunctionInput::Witness(w) = input {
169            return self.solvable_witnesses.contains(w);
170        }
171        true
172    }
173
174    fn can_solve_expression<F>(&self, expr: &Expression<F>) -> bool {
175        for w in Self::expr_witness(expr) {
176            if !self.solvable_witnesses.contains(&w) {
177                return false;
178            }
179        }
180        true
181    }
182
183    fn can_solve_brillig_input<F>(&self, input: &BrilligInputs<F>) -> bool {
184        match input {
185            BrilligInputs::Single(expr) => self.can_solve_expression(expr),
186            BrilligInputs::Array(exprs) => {
187                for expr in exprs {
188                    if !self.can_solve_expression(expr) {
189                        return false;
190                    }
191                }
192                true
193            }
194
195            BrilligInputs::MemoryArray(block_id) => self.initialized_blocks.contains(block_id),
196        }
197    }
198
199    pub(crate) fn expr_witness<F>(expr: &Expression<F>) -> impl Iterator<Item = Witness> {
200        expr.mul_terms
201            .iter()
202            .flat_map(|i| [i.1, i.2])
203            .chain(expr.linear_combinations.iter().map(|i| i.1))
204    }
205}
206
#[cfg(test)]
mod tests {
    use crate::compiler::CircuitSimulator;
    use acir::circuit::Circuit;

    #[test]
    fn reports_none_for_empty_circuit() {
        let src = "
        private parameters: []
        public parameters: []
        return values: []
        ";
        let empty_circuit = Circuit::from_str(src).unwrap();
        assert!(CircuitSimulator::check_circuit(&empty_circuit).is_none());
    }

    #[test]
    fn reports_none_for_connected_circuit() {
        let src = "
        private parameters: [w1]
        public parameters: []
        return values: []
        ASSERT w2 = w1
        ";
        let connected_circuit = Circuit::from_str(src).unwrap();
        assert!(CircuitSimulator::check_circuit(&connected_circuit).is_none());
    }

    #[test]
    fn reports_none_for_connected_circuit_with_range() {
        let src = "
        private parameters: [w1, w3]
        public parameters: []
        return values: []
        ASSERT w2 = w1
        BLACKBOX::RANGE input: w3, bits: 8
        ";
        let connected_circuit = Circuit::from_str(src).unwrap();

        assert!(CircuitSimulator::check_circuit(&connected_circuit).is_none());
    }

    #[test]
    fn reports_none_for_blackbox_output() {
        let src = "
        private parameters: [w0, w1]
        public parameters: []
        return values: []
        BLACKBOX::AND lhs: w0, rhs: w1, output: w2, bits: 32
        ASSERT w3 = w2
        ";
        let circuit = Circuit::from_str(src).unwrap();
        assert!(CircuitSimulator::check_circuit(&circuit).is_none());
    }

    #[test]
    fn reports_none_for_read_memory() {
        let src = "
        private parameters: [w0]
        public parameters: []
        return values: []
        INIT b0 = [w0]
        READ w1 = b0[0]
        ASSERT w2 = w1
        ";
        let circuit = Circuit::from_str(src).unwrap();
        assert!(CircuitSimulator::check_circuit(&circuit).is_none());
    }

    #[test]
    fn reports_none_for_call_output() {
        let src = "
        private parameters: [w0]
        public parameters: []
        return values: []
        CALL func: 0, predicate: 1, inputs: [w0], outputs: [w1]
        ASSERT w2 = w1
        ";
        let circuit = Circuit::from_str(src).unwrap();
        assert!(CircuitSimulator::check_circuit(&circuit).is_none());
    }

    #[test]
    fn reports_none_for_brillig_call_output() {
        let src = "
        private parameters: [w0]
        public parameters: []
        return values: []
        BRILLIG CALL func: 0, predicate: 1, inputs: [w0], outputs: [w1]
        ASSERT w2 = w1
        ";
        let circuit = Circuit::from_str(src).unwrap();
        assert!(CircuitSimulator::check_circuit(&circuit).is_none());
    }

    #[test]
    fn reports_some_for_disconnected_circuit() {
        let src = "
        private parameters: [w1]
        public parameters: []
        return values: []
        ASSERT w2 = w1
        ASSERT w4 = w3
        ";
        let disconnected_circuit = Circuit::from_str(src).unwrap();
        assert_eq!(CircuitSimulator::check_circuit(&disconnected_circuit), Some(1));
    }

    #[test]
    fn reports_some_when_memory_block_is_passed_an_unknown_witness() {
        let src = "
        private parameters: [w1]
        public parameters: []
        return values: []
        ASSERT w1 = 0
        INIT b0 = [w0]
        ";
        let circuit = Circuit::from_str(src).unwrap();
        assert_eq!(CircuitSimulator::check_circuit(&circuit), Some(1));
    }

    #[test]
    fn reports_some_when_attempting_to_reinitialize_memory_block() {
        let src = "
        private parameters: [w0]
        public parameters: []
        return values: []
        INIT b0 = [w0]
        INIT b0 = [w0]
        ";
        let circuit = Circuit::from_str(src).unwrap();
        assert_eq!(CircuitSimulator::check_circuit(&circuit), Some(1));
    }

    #[test]
    fn reports_some_when_reading_uninitialized_memory_block() {
        // Memory must be initialized before any read or write can be solved.
        let src = "
        private parameters: [w0]
        public parameters: []
        return values: []
        READ w1 = b0[0]
        ";
        let circuit = Circuit::from_str(src).unwrap();
        assert_eq!(CircuitSimulator::check_circuit(&circuit), Some(0));
    }

    #[test]
    fn reports_some_when_unknown_witness_is_multiplied_by_itself() {
        // An AssertZero containing just one unknown witness can still be unsolvable
        // if that witness is multiplied by itself, since the constraint is then non-linear.
        let src = "
        private parameters: [w0]
        public parameters: []
        return values: []
        ASSERT w0 = w1*w1
        ";
        let circuit = Circuit::from_str(src).unwrap();
        assert_eq!(CircuitSimulator::check_circuit(&circuit), Some(0));
    }

    #[test]
    fn reports_some_when_write_has_a_single_unknown_witness_in_its_value() {
        let src = "
        private parameters: [w0, w1]
        public parameters: []
        return values: []
        INIT b0 = [w0]
        WRITE b0[w0] = w1 + w2
        ";
        let circuit = Circuit::from_str(src).unwrap();
        assert_eq!(CircuitSimulator::check_circuit(&circuit), Some(1));
    }

    #[test]
    fn reports_none_when_write_has_known_witnesses_in_its_value() {
        let src = "
        private parameters: [w0, w1, w2]
        public parameters: []
        return values: []
        INIT b0 = [w0]
        WRITE b0[w0] = w1 + w2
        ";
        let circuit = Circuit::from_str(src).unwrap();
        assert_eq!(CircuitSimulator::check_circuit(&circuit), None);
    }

    #[test]
    fn reports_some_for_call_with_unknown_input() {
        let src = "
        private parameters: [w0]
        public parameters: []
        return values: []
        CALL func: 0, predicate: 1, inputs: [w5], outputs: [w1]
        ";
        let circuit = Circuit::from_str(src).unwrap();
        assert_eq!(CircuitSimulator::check_circuit(&circuit), Some(0));
    }

    #[test]
    fn reports_some_for_brillig_call_with_unknown_input() {
        let src = "
        private parameters: [w0]
        public parameters: []
        return values: []
        BRILLIG CALL func: 0, predicate: 1, inputs: [w5], outputs: [w1]
        ";
        let circuit = Circuit::from_str(src).unwrap();
        assert_eq!(CircuitSimulator::check_circuit(&circuit), Some(0));
    }

    #[test]
    fn reports_some_when_expression_can_simplify() {
        // `w1 = w1` simplifies to `0 = 0`, which holds trivially but does not make
        // `w1` known, so the following opcode is left with two unknowns.
        let src = "
        private parameters: []
        public parameters: []
        return values: []
        ASSERT w1 = w1
        ASSERT w2 = w1
        ";
        let circuit = Circuit::from_str(src).unwrap();
        assert_eq!(CircuitSimulator::check_circuit(&circuit), Some(1));
    }
}