acvm/pwg/arithmetic.rs

1use acir::{
2    AcirField,
3    native_types::{Expression, Witness, WitnessMap},
4};
5
6use super::{ErrorLocation, OpcodeNotSolvable, OpcodeResolutionError, insert_value};
7
/// Stateless solver for assert-zero opcodes.
///
/// Given an [`Expression`] and the witness values assigned so far,
/// [`ExpressionSolver::solve`] either checks that the expression evaluates to zero or
/// derives the value of the single witness that is still unknown.
pub(crate) struct ExpressionSolver;
11
12#[allow(clippy::enum_variant_names)]
13pub(super) enum OpcodeStatus<F> {
14    OpcodeSatisfied(F),
15    OpcodeSolvable(F, (F, Witness)),
16    OpcodeUnsolvable,
17}
18
19pub(crate) enum MulTerm<F> {
20    OneUnknown(F, Witness), // (qM * known_witness, unknown_witness)
21    TooManyUnknowns,
22    Solved(F),
23}
24
25impl ExpressionSolver {
26    /// Derives the rest of the witness in the provided expression based on the known witness values
27    /// 1. First we simplify the expression based on the known values and try to reduce the multiplication and linear terms
28    /// 2. If we end up with only the constant term;
29    ///     - if it is 0 then the opcode is solved, if not,
30    ///     - the assert_zero opcode is not satisfied and we return an error
31    /// 3. If we end up with only linear terms on the same witness 'w',
32    ///    we can regroup them and solve 'a*w+c = 0':
33    ///    - If 'a' is zero in the above expression;
34    ///      - if c is also 0 then the opcode is solved
35    ///      - if not that means the assert_zero opcode is not satisfied and we return an error
36    ///    - If 'a' is not zero, we can solve it by setting the value of w: 'w = -c/a'
37    pub(crate) fn solve<F: AcirField>(
38        initial_witness: &mut WitnessMap<F>,
39        opcode: &Expression<F>,
40    ) -> Result<(), OpcodeResolutionError<F>> {
41        let opcode = &ExpressionSolver::evaluate(opcode, initial_witness);
42
43        // Evaluate multiplication terms
44        let mul_result = ExpressionSolver::solve_mul_term(&opcode.mul_terms, initial_witness);
45
46        // If we can't solve the multiplication terms, try again by combining multiplication terms
47        // with the same witnesses to see if they all cancel out.
48        let mul_result = if mul_result.is_err() {
49            let mul_terms = ExpressionSolver::combine_mul_terms(&opcode.mul_terms);
50            ExpressionSolver::solve_mul_term(&mul_terms, initial_witness)
51        } else {
52            mul_result
53        };
54
55        let mul_result = mul_result.map_err(|_| {
56            OpcodeResolutionError::OpcodeNotSolvable(
57                OpcodeNotSolvable::ExpressionHasTooManyUnknowns(opcode.clone()),
58            )
59        })?;
60
61        // Evaluate the fan-in terms
62        let opcode_status =
63            ExpressionSolver::solve_fan_in_term(&opcode.linear_combinations, initial_witness);
64
65        // If we can solve the multiplication terms but not the linear terms,
66        // try again by combining linear terms with the same witness.
67        let opcode_status = if matches!(
68            (&mul_result, &opcode_status),
69            (MulTerm::Solved(..), OpcodeStatus::OpcodeUnsolvable)
70        ) {
71            let linear_combinations =
72                ExpressionSolver::combine_linear_terms(&opcode.linear_combinations);
73            ExpressionSolver::solve_fan_in_term(&linear_combinations, initial_witness)
74        } else {
75            opcode_status
76        };
77
78        match (mul_result, opcode_status) {
79            (MulTerm::TooManyUnknowns, _) | (_, OpcodeStatus::OpcodeUnsolvable) => {
80                Err(OpcodeResolutionError::OpcodeNotSolvable(
81                    OpcodeNotSolvable::ExpressionHasTooManyUnknowns(opcode.clone()),
82                ))
83            }
84            (MulTerm::OneUnknown(q, w1), OpcodeStatus::OpcodeSolvable(a, (b, w2))) => {
85                if w1 == w2 {
86                    // We have one unknown so we can solve the equation
87                    let total_sum = a + opcode.q_c;
88                    if (q + b).is_zero() {
89                        if !total_sum.is_zero() {
90                            Err(OpcodeResolutionError::UnsatisfiedConstrain {
91                                opcode_location: ErrorLocation::Unresolved,
92                                payload: None,
93                            })
94                        } else {
95                            Ok(())
96                        }
97                    } else {
98                        let assignment = -quick_invert(total_sum, q + b);
99                        insert_value(&w1, assignment, initial_witness)
100                    }
101                } else {
102                    // TODO(https://github.com/noir-lang/noir/issues/10191): can we be more specific with this error?
103                    Err(OpcodeResolutionError::OpcodeNotSolvable(
104                        OpcodeNotSolvable::ExpressionHasTooManyUnknowns(opcode.clone()),
105                    ))
106                }
107            }
108            (
109                MulTerm::OneUnknown(partial_prod, unknown_var),
110                OpcodeStatus::OpcodeSatisfied(sum),
111            ) => {
112                // We have one unknown in the mul term and the fan-in terms are solved.
113                // Hence the equation is solvable, since there is a single unknown
114                // The equation is: partial_prod * unknown_var + sum + qC = 0
115
116                let total_sum = sum + opcode.q_c;
117                if partial_prod.is_zero() {
118                    if !total_sum.is_zero() {
119                        Err(OpcodeResolutionError::UnsatisfiedConstrain {
120                            opcode_location: ErrorLocation::Unresolved,
121                            payload: None,
122                        })
123                    } else {
124                        Ok(())
125                    }
126                } else {
127                    let assignment = -quick_invert(total_sum, partial_prod);
128                    insert_value(&unknown_var, assignment, initial_witness)
129                }
130            }
131            (MulTerm::Solved(a), OpcodeStatus::OpcodeSatisfied(b)) => {
132                // All the variables in the MulTerm are solved and the Fan-in is also solved
133                // There is nothing to solve
134                if !(a + b + opcode.q_c).is_zero() {
135                    Err(OpcodeResolutionError::UnsatisfiedConstrain {
136                        opcode_location: ErrorLocation::Unresolved,
137                        payload: None,
138                    })
139                } else {
140                    Ok(())
141                }
142            }
143            (
144                MulTerm::Solved(total_prod),
145                OpcodeStatus::OpcodeSolvable(partial_sum, (coeff, unknown_var)),
146            ) => {
147                // The variables in the MulTerm are solved nad there is one unknown in the Fan-in
148                // Hence the equation is solvable, since we have one unknown
149                // The equation is total_prod + partial_sum + coeff * unknown_var + q_C = 0
150                let total_sum = total_prod + partial_sum + opcode.q_c;
151                if coeff.is_zero() {
152                    if !total_sum.is_zero() {
153                        Err(OpcodeResolutionError::UnsatisfiedConstrain {
154                            opcode_location: ErrorLocation::Unresolved,
155                            payload: None,
156                        })
157                    } else {
158                        Ok(())
159                    }
160                } else {
161                    let assignment = -quick_invert(total_sum, coeff);
162                    insert_value(&unknown_var, assignment, initial_witness)
163                }
164            }
165        }
166    }
167
168    /// Try to reduce the multiplication terms of the given expression's mul terms to a known value or to a linear term,
169    /// using the provided witness mapping.
170    /// If there are 2 or more multiplication terms it returns the OpcodeUnsolvable error.
171    /// If no witnesses value is in the provided 'witness_assignments' map,
172    /// it returns MulTerm::TooManyUnknowns
173    fn solve_mul_term<F: AcirField>(
174        mul_terms: &[(F, Witness, Witness)],
175        witness_assignments: &WitnessMap<F>,
176    ) -> Result<MulTerm<F>, OpcodeStatus<F>> {
177        // First note that the mul term can only contain one/zero term,
178        // e.g. that it has been optimized, or else we're returning OpcodeUnsolvable
179        match mul_terms.len() {
180            0 => Ok(MulTerm::Solved(F::zero())),
181            1 => Ok(ExpressionSolver::solve_mul_term_helper(&mul_terms[0], witness_assignments)),
182            _ => Err(OpcodeStatus::OpcodeUnsolvable),
183        }
184    }
185
186    /// Try to solve a multiplication term of the form q*a*b, where
187    /// q is a constant and a,b are witnesses
188    /// If both a and b have known values (in the provided map), it returns the value q*a*b
189    /// If only one of a or b has a known value, it returns the linear term c*w where c is a constant and w is the unknown witness
190    /// If both a and b are unknown, it returns MulTerm::TooManyUnknowns
191    fn solve_mul_term_helper<F: AcirField>(
192        term: &(F, Witness, Witness),
193        witness_assignments: &WitnessMap<F>,
194    ) -> MulTerm<F> {
195        let (q_m, w_l, w_r) = term;
196        // Check if these values are in the witness assignments
197        let w_l_value = witness_assignments.get(w_l);
198        let w_r_value = witness_assignments.get(w_r);
199
200        match (w_l_value, w_r_value) {
201            (None, None) => MulTerm::TooManyUnknowns,
202            (Some(w_l), Some(w_r)) => MulTerm::Solved(*q_m * *w_l * *w_r),
203            (None, Some(w_r)) => MulTerm::OneUnknown(*q_m * *w_r, *w_l),
204            (Some(w_l), None) => MulTerm::OneUnknown(*q_m * *w_l, *w_r),
205        }
206    }
207
208    /// Reduce a linear term to its value if the witness assignment is known
209    /// If the witness value is not known in the provided map, it returns None.
210    fn solve_fan_in_term_helper<F: AcirField>(
211        term: &(F, Witness),
212        witness_assignments: &WitnessMap<F>,
213    ) -> Option<F> {
214        let (q_l, w_l) = term;
215        // Check if we have w_l
216        let w_l_value = witness_assignments.get(w_l);
217        w_l_value.map(|a| *q_l * *a)
218    }
219
220    /// Returns the summation of all of the variables, plus the unknown variable
221    /// Returns [`OpcodeStatus::OpcodeUnsolvable`], if there is more than one unknown variable
222    pub(super) fn solve_fan_in_term<F: AcirField>(
223        linear_combinations: &[(F, Witness)],
224        witness_assignments: &WitnessMap<F>,
225    ) -> OpcodeStatus<F> {
226        // If the fan-in has more than 0 num_unknowns:
227
228        // This is the variable that we want to assign the value to
229        let mut unknown_variable = (F::zero(), Witness::default());
230        let mut num_unknowns = 0;
231        // This is the sum of all of the known variables
232        let mut result = F::zero();
233
234        for term in linear_combinations {
235            let value = ExpressionSolver::solve_fan_in_term_helper(term, witness_assignments);
236            match value {
237                Some(a) => result += a,
238                None => {
239                    unknown_variable = *term;
240                    num_unknowns += 1;
241                }
242            }
243
244            // If we have more than 1 unknown, then we cannot solve this equation
245            if num_unknowns > 1 {
246                return OpcodeStatus::OpcodeUnsolvable;
247            }
248        }
249
250        if num_unknowns == 0 {
251            return OpcodeStatus::OpcodeSatisfied(result);
252        }
253
254        OpcodeStatus::OpcodeSolvable(result, unknown_variable)
255    }
256
257    // Partially evaluate the opcode using the known witnesses
258    // For instance if values of witness 'a' and 'b' are known, then
259    // the multiplication 'a*b' is removed and their multiplied values are added to the constant term
260    // If only witness 'a' is known, then the multiplication 'a*b' is replaced by the linear term '(value of b)*a'
261    // etc ...
262    // If all values are known, the partial evaluation gives a constant expression
263    // If no value is known, the partial evaluation returns the original expression
264    pub(crate) fn evaluate<F: AcirField>(
265        expr: &Expression<F>,
266        initial_witness: &WitnessMap<F>,
267    ) -> Expression<F> {
268        let mut result = Expression::default();
269        for &(c, w1, w2) in &expr.mul_terms {
270            let mul_result = ExpressionSolver::solve_mul_term_helper(&(c, w1, w2), initial_witness);
271            match mul_result {
272                MulTerm::OneUnknown(v, w) => {
273                    if !v.is_zero() {
274                        result.linear_combinations.push((v, w));
275                    }
276                }
277                MulTerm::TooManyUnknowns => {
278                    if !c.is_zero() {
279                        result.mul_terms.push((c, w1, w2));
280                    }
281                }
282                MulTerm::Solved(f) => result.q_c += f,
283            }
284        }
285        for &(c, w) in &expr.linear_combinations {
286            if let Some(f) = ExpressionSolver::solve_fan_in_term_helper(&(c, w), initial_witness) {
287                result.q_c += f;
288            } else if !c.is_zero() {
289                result.linear_combinations.push((c, w));
290            }
291        }
292        result.q_c += expr.q_c;
293        result
294    }
295
296    /// Combines linear terms with the same witness by summing their coefficients.
297    /// For example `w1 + 2*w1` becomes `3*w1`.
298    pub(crate) fn combine_linear_terms<F: AcirField>(
299        linear_combinations: &[(F, Witness)],
300    ) -> Vec<(F, Witness)> {
301        let mut combined_linear_combinations = std::collections::HashMap::new();
302
303        for (c, w) in linear_combinations {
304            let existing_c = combined_linear_combinations.entry(*w).or_insert(F::zero());
305            *existing_c += *c;
306        }
307
308        combined_linear_combinations
309            .into_iter()
310            .filter_map(
311                |(witness, coeff)| {
312                    if !coeff.is_zero() { Some((coeff, witness)) } else { None }
313                },
314            )
315            .collect()
316    }
317
318    /// Combines multiplication terms with the same witnesses by summing their coefficients.
319    /// For example `w1*w2 + 2*w2*w1` becomes `3*w1*w2`. If a coefficient ends up being zero,
320    /// the term is removed.
321    pub(crate) fn combine_mul_terms<F: AcirField>(
322        mul_terms: &[(F, Witness, Witness)],
323    ) -> Vec<(F, Witness, Witness)> {
324        // This is similar to GeneralOptimizer::simplify_mul_terms but it's duplicated because
325        // we don't have access to the acvm crate here.
326        let mut hash_map = std::collections::HashMap::new();
327
328        // Canonicalize the ordering of the multiplication, lets just order by variable name
329        for (scale, w_l, w_r) in mul_terms.iter().copied() {
330            let mut pair = [w_l, w_r];
331            pair.sort();
332
333            *hash_map.entry((pair[0], pair[1])).or_insert_with(F::zero) += scale;
334        }
335
336        hash_map
337            .into_iter()
338            .filter(|(_, scale)| !scale.is_zero())
339            .map(|((w_l, w_r), scale)| (scale, w_l, w_r))
340            .collect()
341    }
342}
343
344/// A wrapper around field division which skips the inversion if the denominator
345/// is ±1.
346///
347/// Field inversion is the most significant cost of solving [`Opcode::AssertZero`][acir::circuit::opcodes::Opcode::AssertZero]
348/// opcodes, which we can avoid when the denominator is ±1.
349fn quick_invert<F: AcirField>(numerator: F, denominator: F) -> F {
350    if denominator == F::one() {
351        numerator
352    } else if denominator == -F::one() {
353        -numerator
354    } else {
355        assert!(
356            denominator != F::zero(),
357            "quick_invert: attempting to divide numerator by F::zero()"
358        );
359        numerator / denominator
360    }
361}
362
#[cfg(test)]
mod tests {
    use super::*;
    use acir::FieldElement;

    #[test]
    /// Sanity check for the special cases of [`quick_invert`]
    fn quick_invert_matches_slow_invert() {
        let numerator = FieldElement::from_be_bytes_reduce("hello_world".as_bytes());
        assert_eq!(quick_invert(numerator, FieldElement::one()), numerator / FieldElement::one());
        assert_eq!(quick_invert(numerator, -FieldElement::one()), numerator / -FieldElement::one());
    }

    #[test]
    #[should_panic(expected = "quick_invert: attempting to divide numerator by F::zero()")]
    fn quick_invert_zero_denominator() {
        quick_invert(FieldElement::one(), FieldElement::zero());
    }

    #[test]
    fn solves_simple_assignment() {
        let a = Witness(0);

        // a - 1 == 0;
        let opcode_a = Expression::from_str(&format!("{a} - 1")).unwrap();

        let mut values = WitnessMap::new();
        assert_eq!(ExpressionSolver::solve(&mut values, &opcode_a), Ok(()));

        assert_eq!(values.get(&a).unwrap(), &FieldElement::from(1_i128));
    }

    #[test]
    fn solves_unknown_in_mul_term() {
        let a = Witness(0);
        let b = Witness(1);
        let c = Witness(2);
        let d = Witness(3);

        // a * b - b - c - d == 0;
        let opcode_a = Expression::from_str(&format!("{a}*{b} - {b} - {c} - {d}")).unwrap();

        let mut values = WitnessMap::new();
        values.insert(b, FieldElement::from(2_i128));
        values.insert(c, FieldElement::from(1_i128));
        values.insert(d, FieldElement::from(1_i128));

        assert_eq!(ExpressionSolver::solve(&mut values, &opcode_a), Ok(()));

        assert_eq!(values.get(&a).unwrap(), &FieldElement::from(2_i128));
    }

    #[test]
    fn solves_unknown_in_linear_term() {
        let a = Witness(0);
        let b = Witness(1);
        let c = Witness(2);
        let d = Witness(3);

        // a = b + c + d;
        let opcode_a = Expression::from_str(&format!("{a} - {b} - {c} - {d}")).unwrap();

        let e = Witness(4);
        let opcode_b = Expression::from_str(&format!("{e} - {a} - {b}")).unwrap();

        let mut values = WitnessMap::new();
        values.insert(b, FieldElement::from(2_i128));
        values.insert(c, FieldElement::from(1_i128));
        values.insert(d, FieldElement::from(1_i128));

        assert_eq!(ExpressionSolver::solve(&mut values, &opcode_a), Ok(()));
        assert_eq!(ExpressionSolver::solve(&mut values, &opcode_b), Ok(()));

        assert_eq!(values.get(&a).unwrap(), &FieldElement::from(4_i128));
    }

    #[test]
    fn solves_by_combining_linear_terms_after_they_have_been_multiplied_by_known_witnesses() {
        let expr = Expression::from_str("w1 + w1*w0 - 4").unwrap();
        let mut values = WitnessMap::new();
        values.insert(Witness(0), FieldElement::from(1_i128));

        let res = ExpressionSolver::solve(&mut values, &expr);
        assert!(res.is_ok());

        assert_eq!(values.get(&Witness(1)).unwrap(), &FieldElement::from(2_i128));
    }

    #[test]
    fn solves_by_combining_mul_terms() {
        let expr = Expression::from_str("w1*w2 - w2*w1 + w3 - 2").unwrap();
        let mut values = WitnessMap::new();

        let res = ExpressionSolver::solve(&mut values, &expr);
        assert!(res.is_ok());

        assert_eq!(values.get(&Witness(3)).unwrap(), &FieldElement::from(2_i128));
    }

    #[test]
    fn solves_by_combining_mul_terms_and_then_linear_terms() {
        // The products cancel out and the two `w3` terms collapse into `2*w3 - 2 == 0`.
        let expr = Expression::from_str("w1*w2 - w2*w1 + w3 + w3 - 2").unwrap();
        let mut values = WitnessMap::new();

        assert!(ExpressionSolver::solve(&mut values, &expr).is_ok());

        assert_eq!(values.get(&Witness(3)).unwrap(), &FieldElement::from(1_i128));
    }

    #[test]
    fn errors_when_fully_known_expression_is_not_zero() {
        let a = Witness(0);

        // a - 1 == 0, but `a` is already assigned 2.
        let opcode = Expression::from_str(&format!("{a} - 1")).unwrap();
        let mut values = WitnessMap::new();
        values.insert(a, FieldElement::from(2_i128));

        assert!(matches!(
            ExpressionSolver::solve(&mut values, &opcode),
            Err(OpcodeResolutionError::UnsatisfiedConstrain { .. })
        ));
    }

    #[test]
    fn errors_when_two_distinct_linear_unknowns_remain() {
        let a = Witness(0);
        let b = Witness(1);

        // a + b - 1 == 0 has two unknowns and cannot be solved on its own.
        let opcode = Expression::from_str(&format!("{a} + {b} - 1")).unwrap();
        let mut values = WitnessMap::new();

        assert!(matches!(
            ExpressionSolver::solve(&mut values, &opcode),
            Err(OpcodeResolutionError::OpcodeNotSolvable(
                OpcodeNotSolvable::ExpressionHasTooManyUnknowns(_)
            ))
        ));
        assert!(values.get(&a).is_none());
        assert!(values.get(&b).is_none());
    }

    #[test]
    fn combine_linear_terms_merges_and_drops_cancelled_witnesses() {
        let terms = [
            (FieldElement::one(), Witness(0)),
            (-FieldElement::one(), Witness(0)),
            (FieldElement::from(2_i128), Witness(1)),
            (FieldElement::from(3_i128), Witness(1)),
        ];

        assert_eq!(
            ExpressionSolver::combine_linear_terms(&terms),
            vec![(FieldElement::from(5_i128), Witness(1))]
        );
    }

    #[test]
    fn combine_mul_terms_merges_commuted_products() {
        let terms = [
            (FieldElement::one(), Witness(2), Witness(1)),
            (FieldElement::from(2_i128), Witness(1), Witness(2)),
        ];
        assert_eq!(
            ExpressionSolver::combine_mul_terms(&terms),
            vec![(FieldElement::from(3_i128), Witness(1), Witness(2))]
        );

        let cancelling = [
            (FieldElement::one(), Witness(1), Witness(2)),
            (-FieldElement::one(), Witness(2), Witness(1)),
        ];
        assert!(ExpressionSolver::combine_mul_terms(&cancelling).is_empty());
    }
}