1use acir::{
2 AcirField,
3 native_types::{Expression, Witness, WitnessMap},
4};
5
6use super::{ErrorLocation, OpcodeNotSolvable, OpcodeResolutionError, insert_value};
7
/// Stateless namespace for the routines that solve an expression of the form `expr == 0`
/// against a witness map. The entry point is `ExpressionSolver::solve`; the remaining
/// associated functions are the helpers it is built from.
pub(crate) struct ExpressionSolver;
11
12#[allow(clippy::enum_variant_names)]
13pub(super) enum OpcodeStatus<F> {
14 OpcodeSatisfied(F),
15 OpcodeSolvable(F, (F, Witness)),
16 OpcodeUnsolvable,
17}
18
19pub(crate) enum MulTerm<F> {
20 OneUnknown(F, Witness), TooManyUnknowns,
22 Solved(F),
23}
24
25impl ExpressionSolver {
26 pub(crate) fn solve<F: AcirField>(
38 initial_witness: &mut WitnessMap<F>,
39 opcode: &Expression<F>,
40 ) -> Result<(), OpcodeResolutionError<F>> {
41 let opcode = &ExpressionSolver::evaluate(opcode, initial_witness);
42
43 let mul_result = ExpressionSolver::solve_mul_term(&opcode.mul_terms, initial_witness);
45
46 let mul_result = if mul_result.is_err() {
49 let mul_terms = ExpressionSolver::combine_mul_terms(&opcode.mul_terms);
50 ExpressionSolver::solve_mul_term(&mul_terms, initial_witness)
51 } else {
52 mul_result
53 };
54
55 let mul_result = mul_result.map_err(|_| {
56 OpcodeResolutionError::OpcodeNotSolvable(
57 OpcodeNotSolvable::ExpressionHasTooManyUnknowns(opcode.clone()),
58 )
59 })?;
60
61 let opcode_status =
63 ExpressionSolver::solve_fan_in_term(&opcode.linear_combinations, initial_witness);
64
65 let opcode_status = if matches!(
68 (&mul_result, &opcode_status),
69 (MulTerm::Solved(..), OpcodeStatus::OpcodeUnsolvable)
70 ) {
71 let linear_combinations =
72 ExpressionSolver::combine_linear_terms(&opcode.linear_combinations);
73 ExpressionSolver::solve_fan_in_term(&linear_combinations, initial_witness)
74 } else {
75 opcode_status
76 };
77
78 match (mul_result, opcode_status) {
79 (MulTerm::TooManyUnknowns, _) | (_, OpcodeStatus::OpcodeUnsolvable) => {
80 Err(OpcodeResolutionError::OpcodeNotSolvable(
81 OpcodeNotSolvable::ExpressionHasTooManyUnknowns(opcode.clone()),
82 ))
83 }
84 (MulTerm::OneUnknown(q, w1), OpcodeStatus::OpcodeSolvable(a, (b, w2))) => {
85 if w1 == w2 {
86 let total_sum = a + opcode.q_c;
88 if (q + b).is_zero() {
89 if !total_sum.is_zero() {
90 Err(OpcodeResolutionError::UnsatisfiedConstrain {
91 opcode_location: ErrorLocation::Unresolved,
92 payload: None,
93 })
94 } else {
95 Ok(())
96 }
97 } else {
98 let assignment = -quick_invert(total_sum, q + b);
99 insert_value(&w1, assignment, initial_witness)
100 }
101 } else {
102 Err(OpcodeResolutionError::OpcodeNotSolvable(
104 OpcodeNotSolvable::ExpressionHasTooManyUnknowns(opcode.clone()),
105 ))
106 }
107 }
108 (
109 MulTerm::OneUnknown(partial_prod, unknown_var),
110 OpcodeStatus::OpcodeSatisfied(sum),
111 ) => {
112 let total_sum = sum + opcode.q_c;
117 if partial_prod.is_zero() {
118 if !total_sum.is_zero() {
119 Err(OpcodeResolutionError::UnsatisfiedConstrain {
120 opcode_location: ErrorLocation::Unresolved,
121 payload: None,
122 })
123 } else {
124 Ok(())
125 }
126 } else {
127 let assignment = -quick_invert(total_sum, partial_prod);
128 insert_value(&unknown_var, assignment, initial_witness)
129 }
130 }
131 (MulTerm::Solved(a), OpcodeStatus::OpcodeSatisfied(b)) => {
132 if !(a + b + opcode.q_c).is_zero() {
135 Err(OpcodeResolutionError::UnsatisfiedConstrain {
136 opcode_location: ErrorLocation::Unresolved,
137 payload: None,
138 })
139 } else {
140 Ok(())
141 }
142 }
143 (
144 MulTerm::Solved(total_prod),
145 OpcodeStatus::OpcodeSolvable(partial_sum, (coeff, unknown_var)),
146 ) => {
147 let total_sum = total_prod + partial_sum + opcode.q_c;
151 if coeff.is_zero() {
152 if !total_sum.is_zero() {
153 Err(OpcodeResolutionError::UnsatisfiedConstrain {
154 opcode_location: ErrorLocation::Unresolved,
155 payload: None,
156 })
157 } else {
158 Ok(())
159 }
160 } else {
161 let assignment = -quick_invert(total_sum, coeff);
162 insert_value(&unknown_var, assignment, initial_witness)
163 }
164 }
165 }
166 }
167
168 fn solve_mul_term<F: AcirField>(
174 mul_terms: &[(F, Witness, Witness)],
175 witness_assignments: &WitnessMap<F>,
176 ) -> Result<MulTerm<F>, OpcodeStatus<F>> {
177 match mul_terms.len() {
180 0 => Ok(MulTerm::Solved(F::zero())),
181 1 => Ok(ExpressionSolver::solve_mul_term_helper(&mul_terms[0], witness_assignments)),
182 _ => Err(OpcodeStatus::OpcodeUnsolvable),
183 }
184 }
185
186 fn solve_mul_term_helper<F: AcirField>(
192 term: &(F, Witness, Witness),
193 witness_assignments: &WitnessMap<F>,
194 ) -> MulTerm<F> {
195 let (q_m, w_l, w_r) = term;
196 let w_l_value = witness_assignments.get(w_l);
198 let w_r_value = witness_assignments.get(w_r);
199
200 match (w_l_value, w_r_value) {
201 (None, None) => MulTerm::TooManyUnknowns,
202 (Some(w_l), Some(w_r)) => MulTerm::Solved(*q_m * *w_l * *w_r),
203 (None, Some(w_r)) => MulTerm::OneUnknown(*q_m * *w_r, *w_l),
204 (Some(w_l), None) => MulTerm::OneUnknown(*q_m * *w_l, *w_r),
205 }
206 }
207
208 fn solve_fan_in_term_helper<F: AcirField>(
211 term: &(F, Witness),
212 witness_assignments: &WitnessMap<F>,
213 ) -> Option<F> {
214 let (q_l, w_l) = term;
215 let w_l_value = witness_assignments.get(w_l);
217 w_l_value.map(|a| *q_l * *a)
218 }
219
220 pub(super) fn solve_fan_in_term<F: AcirField>(
223 linear_combinations: &[(F, Witness)],
224 witness_assignments: &WitnessMap<F>,
225 ) -> OpcodeStatus<F> {
226 let mut unknown_variable = (F::zero(), Witness::default());
230 let mut num_unknowns = 0;
231 let mut result = F::zero();
233
234 for term in linear_combinations {
235 let value = ExpressionSolver::solve_fan_in_term_helper(term, witness_assignments);
236 match value {
237 Some(a) => result += a,
238 None => {
239 unknown_variable = *term;
240 num_unknowns += 1;
241 }
242 }
243
244 if num_unknowns > 1 {
246 return OpcodeStatus::OpcodeUnsolvable;
247 }
248 }
249
250 if num_unknowns == 0 {
251 return OpcodeStatus::OpcodeSatisfied(result);
252 }
253
254 OpcodeStatus::OpcodeSolvable(result, unknown_variable)
255 }
256
257 pub(crate) fn evaluate<F: AcirField>(
265 expr: &Expression<F>,
266 initial_witness: &WitnessMap<F>,
267 ) -> Expression<F> {
268 let mut result = Expression::default();
269 for &(c, w1, w2) in &expr.mul_terms {
270 let mul_result = ExpressionSolver::solve_mul_term_helper(&(c, w1, w2), initial_witness);
271 match mul_result {
272 MulTerm::OneUnknown(v, w) => {
273 if !v.is_zero() {
274 result.linear_combinations.push((v, w));
275 }
276 }
277 MulTerm::TooManyUnknowns => {
278 if !c.is_zero() {
279 result.mul_terms.push((c, w1, w2));
280 }
281 }
282 MulTerm::Solved(f) => result.q_c += f,
283 }
284 }
285 for &(c, w) in &expr.linear_combinations {
286 if let Some(f) = ExpressionSolver::solve_fan_in_term_helper(&(c, w), initial_witness) {
287 result.q_c += f;
288 } else if !c.is_zero() {
289 result.linear_combinations.push((c, w));
290 }
291 }
292 result.q_c += expr.q_c;
293 result
294 }
295
296 pub(crate) fn combine_linear_terms<F: AcirField>(
299 linear_combinations: &[(F, Witness)],
300 ) -> Vec<(F, Witness)> {
301 let mut combined_linear_combinations = std::collections::HashMap::new();
302
303 for (c, w) in linear_combinations {
304 let existing_c = combined_linear_combinations.entry(*w).or_insert(F::zero());
305 *existing_c += *c;
306 }
307
308 combined_linear_combinations
309 .into_iter()
310 .filter_map(
311 |(witness, coeff)| {
312 if !coeff.is_zero() { Some((coeff, witness)) } else { None }
313 },
314 )
315 .collect()
316 }
317
318 pub(crate) fn combine_mul_terms<F: AcirField>(
322 mul_terms: &[(F, Witness, Witness)],
323 ) -> Vec<(F, Witness, Witness)> {
324 let mut hash_map = std::collections::HashMap::new();
327
328 for (scale, w_l, w_r) in mul_terms.iter().copied() {
330 let mut pair = [w_l, w_r];
331 pair.sort();
332
333 *hash_map.entry((pair[0], pair[1])).or_insert_with(F::zero) += scale;
334 }
335
336 hash_map
337 .into_iter()
338 .filter(|(_, scale)| !scale.is_zero())
339 .map(|((w_l, w_r), scale)| (scale, w_l, w_r))
340 .collect()
341 }
342}
343
344fn quick_invert<F: AcirField>(numerator: F, denominator: F) -> F {
350 if denominator == F::one() {
351 numerator
352 } else if denominator == -F::one() {
353 -numerator
354 } else {
355 assert!(
356 denominator != F::zero(),
357 "quick_invert: attempting to divide numerator by F::zero()"
358 );
359 numerator / denominator
360 }
361}
362
#[cfg(test)]
mod tests {
    use super::*;
    use acir::FieldElement;

    #[test]
    fn quick_invert_matches_slow_invert() {
        let numerator = FieldElement::from_be_bytes_reduce("hello_world".as_bytes());
        assert_eq!(quick_invert(numerator, FieldElement::one()), numerator / FieldElement::one());
        assert_eq!(quick_invert(numerator, -FieldElement::one()), numerator / -FieldElement::one());
    }

    #[test]
    #[should_panic(expected = "quick_invert: attempting to divide numerator by F::zero()")]
    fn quick_invert_zero_denominator() {
        quick_invert(FieldElement::one(), FieldElement::zero());
    }

    #[test]
    fn solves_simple_assignment() {
        let a = Witness(0);

        let opcode_a = Expression::from_str(&format!("{a} - 1")).unwrap();

        let mut values = WitnessMap::new();
        assert_eq!(ExpressionSolver::solve(&mut values, &opcode_a), Ok(()));

        assert_eq!(values.get(&a).unwrap(), &FieldElement::from(1_i128));
    }

    #[test]
    fn solves_unknown_in_mul_term() {
        let a = Witness(0);
        let b = Witness(1);
        let c = Witness(2);
        let d = Witness(3);

        let opcode_a = Expression::from_str(&format!("{a}*{b} - {b} - {c} - {d}")).unwrap();

        let mut values = WitnessMap::new();
        values.insert(b, FieldElement::from(2_i128));
        values.insert(c, FieldElement::from(1_i128));
        values.insert(d, FieldElement::from(1_i128));

        assert_eq!(ExpressionSolver::solve(&mut values, &opcode_a), Ok(()));

        assert_eq!(values.get(&a).unwrap(), &FieldElement::from(2_i128));
    }

    #[test]
    fn solves_unknown_in_linear_term() {
        let a = Witness(0);
        let b = Witness(1);
        let c = Witness(2);
        let d = Witness(3);

        let opcode_a = Expression::from_str(&format!("{a} - {b} - {c} - {d}")).unwrap();

        let e = Witness(4);
        let opcode_b = Expression::from_str(&format!("{e} - {a} - {b}")).unwrap();

        let mut values = WitnessMap::new();
        values.insert(b, FieldElement::from(2_i128));
        values.insert(c, FieldElement::from(1_i128));
        values.insert(d, FieldElement::from(1_i128));

        assert_eq!(ExpressionSolver::solve(&mut values, &opcode_a), Ok(()));
        assert_eq!(ExpressionSolver::solve(&mut values, &opcode_b), Ok(()));

        assert_eq!(values.get(&a).unwrap(), &FieldElement::from(4_i128));
    }

    #[test]
    fn solves_by_combining_linear_terms_after_they_have_been_multiplied_by_known_witnesses() {
        let expr = Expression::from_str("w1 + w1*w0 - 4").unwrap();
        let mut values = WitnessMap::new();
        values.insert(Witness(0), FieldElement::from(1_i128));

        let res = ExpressionSolver::solve(&mut values, &expr);
        assert!(res.is_ok());

        assert_eq!(values.get(&Witness(1)).unwrap(), &FieldElement::from(2_i128));
    }

    #[test]
    fn solves_by_combining_mul_terms() {
        let expr = Expression::from_str("w1*w2 - w2*w1 + w3 - 2").unwrap();
        let mut values = WitnessMap::new();

        let res = ExpressionSolver::solve(&mut values, &expr);
        assert!(res.is_ok());

        assert_eq!(values.get(&Witness(3)).unwrap(), &FieldElement::from(2_i128));
    }

    #[test]
    fn reports_unsatisfied_constraint_when_fully_known_expression_is_nonzero() {
        // With w0 = 2, `w0 - 1` evaluates to 1, which is not zero.
        let expr = Expression::from_str("w0 - 1").unwrap();
        let mut values = WitnessMap::new();
        values.insert(Witness(0), FieldElement::from(2_i128));

        let res = ExpressionSolver::solve(&mut values, &expr);
        assert!(matches!(res, Err(OpcodeResolutionError::UnsatisfiedConstrain { .. })));
    }

    #[test]
    fn rejects_two_unknown_linear_witnesses() {
        let expr = Expression::from_str("w0 + w1 - 3").unwrap();
        let mut values = WitnessMap::new();

        let res = ExpressionSolver::solve(&mut values, &expr);
        assert!(matches!(
            res,
            Err(OpcodeResolutionError::OpcodeNotSolvable(
                OpcodeNotSolvable::ExpressionHasTooManyUnknowns(_)
            ))
        ));
        assert!(values.get(&Witness(0)).is_none());
        assert!(values.get(&Witness(1)).is_none());
    }

    #[test]
    fn rejects_product_of_unknown_witnesses() {
        let expr = Expression::from_str("w0*w0 - 4").unwrap();
        let mut values = WitnessMap::new();

        let res = ExpressionSolver::solve(&mut values, &expr);
        assert!(matches!(
            res,
            Err(OpcodeResolutionError::OpcodeNotSolvable(
                OpcodeNotSolvable::ExpressionHasTooManyUnknowns(_)
            ))
        ));
        assert!(values.get(&Witness(0)).is_none());
    }

    #[test]
    fn evaluate_folds_known_witnesses() {
        // With w0 = 2, `w0*w1 + w2 + 4*w0 - 3` becomes `2*w1 + w2 + 5`.
        let mut expr: Expression<FieldElement> = Expression::default();
        expr.mul_terms.push((FieldElement::one(), Witness(0), Witness(1)));
        expr.linear_combinations.push((FieldElement::one(), Witness(2)));
        expr.linear_combinations.push((FieldElement::from(4_i128), Witness(0)));
        expr.q_c = -FieldElement::from(3_i128);

        let mut values = WitnessMap::new();
        values.insert(Witness(0), FieldElement::from(2_i128));

        let evaluated = ExpressionSolver::evaluate(&expr, &values);
        assert!(evaluated.mul_terms.is_empty());
        assert_eq!(
            evaluated.linear_combinations,
            vec![(FieldElement::from(2_i128), Witness(1)), (FieldElement::one(), Witness(2))]
        );
        assert_eq!(evaluated.q_c, FieldElement::from(5_i128));
    }

    #[test]
    fn combine_linear_terms_drops_cancelled_witnesses() {
        let terms = [
            (FieldElement::one(), Witness(0)),
            (-FieldElement::one(), Witness(0)),
            (FieldElement::from(3_i128), Witness(1)),
        ];
        assert_eq!(
            ExpressionSolver::combine_linear_terms(&terms),
            vec![(FieldElement::from(3_i128), Witness(1))]
        );
    }

    #[test]
    fn combine_mul_terms_merges_commuted_products() {
        let terms = [
            (FieldElement::from(2_i128), Witness(0), Witness(1)),
            (FieldElement::from(3_i128), Witness(1), Witness(0)),
        ];
        assert_eq!(
            ExpressionSolver::combine_mul_terms(&terms),
            vec![(FieldElement::from(5_i128), Witness(0), Witness(1))]
        );
    }
}