acvm/pwg/arithmetic.rs
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 147 148 149 150 151 152 153 154 155 156 157 158 159 160 161 162 163 164 165 166 167 168 169 170 171 172 173 174 175 176 177 178 179 180 181 182 183 184 185 186 187 188 189 190 191 192 193 194 195 196 197 198 199 200 201 202 203 204 205 206 207 208 209 210 211 212 213 214 215 216 217 218 219 220 221 222 223 224 225 226 227 228 229 230 231 232 233 234 235 236 237 238 239 240 241 242 243 244 245 246 247 248 249 250 251 252 253 254 255 256 257 258 259 260 261 262 263 264 265 266 267 268 269 270 271 272 273 274 275 276 277 278 279 280 281 282 283 284 285 286 287 288 289 290 291 292 293 294 295 296 297 298 299 300 301 302 303 304 305 306 307 308 309 310 311 312 313 314 315 316 317 318 319 320 321 322 323 324 325 326 327 328 329 330 331 332 333 334 335 336 337 338 339 340 341 342 343 344 345 346 347 348 349 350 351 352 353 354 355 356 357 358 359 360 361 362 363 364 365 366 367 368 369 370 371 372 373 374 375 376 377 378 379 380 381 382 383 384 385 386 387
use acir::{
AcirField,
native_types::{Expression, Witness, WitnessMap},
};
use super::{ErrorLocation, OpcodeNotSolvable, OpcodeResolutionError, insert_value};
/// Solver for assert-zero opcodes.
///
/// Given an [`Expression`] that must evaluate to zero and a partial witness assignment,
/// it either checks that the opcode holds or derives the value of its single missing
/// witness and inserts it into the witness map.
pub(crate) struct ExpressionSolver;
/// Outcome of evaluating the linear (fan-in) terms of an expression against a witness map.
///
/// `ExpressionSolver::solve_mul_term` also uses it as its error type, returning
/// `OpcodeUnsolvable` when an expression has more than one multiplication term.
// Every variant starts with `Opcode`, which clippy's `enum_variant_names` lint would flag.
#[allow(clippy::enum_variant_names)]
pub(super) enum OpcodeStatus<F> {
    /// No unknown remains; holds the sum of the evaluated linear terms.
    /// Whether the whole opcode evaluates to zero is checked separately by the caller.
    OpcodeSatisfied(F),
    /// A single unknown remains; holds the sum of the known terms and the
    /// `(coefficient, witness)` of the unknown term.
    OpcodeSolvable(F, (F, Witness)),
    /// The terms cannot be reduced to a single unknown.
    OpcodeUnsolvable,
}
/// Result of evaluating a multiplication term `q_m * a * b` against a witness map.
pub(crate) enum MulTerm<F> {
    /// Exactly one of `a`, `b` is known: the product reduces to a linear term.
    OneUnknown(F, Witness), // (q_m * known witness value, unknown witness)
    /// Neither `a` nor `b` has a known value.
    TooManyUnknowns,
    /// Both `a` and `b` are known: holds `q_m * a * b`.
    /// `ExpressionSolver::solve_mul_term` also returns `Solved(0)` when there is no
    /// multiplication term at all.
    Solved(F),
}
impl ExpressionSolver {
    /// Derives the missing witness of an assert-zero opcode from the known witness values.
    ///
    /// 1. The expression is first partially evaluated with [`ExpressionSolver::evaluate`]:
    ///    known witnesses are folded into the constant term and products with one known
    ///    operand become linear terms.
    /// 2. If only the constant term is left:
    ///    - if it is 0 then the opcode is satisfied;
    ///    - otherwise the assert-zero opcode is not satisfied and an error is returned.
    /// 3. If what is left is linear in a single witness `w` (possibly spread over several
    ///    terms), the terms are regrouped into `a*w + c = 0`:
    ///    - if `a` is zero, the opcode is satisfied when `c` is also 0 and unsatisfied
    ///      otherwise (in both cases `w` is left unassigned);
    ///    - if `a` is not zero, `w` is assigned the value `-c/a`.
    /// 4. Anything else (several distinct unknown witnesses, or a product of two unknowns)
    ///    cannot be solved.
    ///
    /// # Errors
    ///
    /// - [`OpcodeResolutionError::OpcodeNotSolvable`] with
    ///   [`OpcodeNotSolvable::ExpressionHasTooManyUnknowns`] (carrying the partially
    ///   evaluated expression) in case 4.
    /// - [`OpcodeResolutionError::UnsatisfiedConstrain`] when the opcode cannot evaluate
    ///   to zero.
    /// - Any error returned by `insert_value` when recording the derived witness value.
    pub(crate) fn solve<F: AcirField>(
        initial_witness: &mut WitnessMap<F>,
        opcode: &Expression<F>,
    ) -> Result<(), OpcodeResolutionError<F>> {
        let opcode = &ExpressionSolver::evaluate(opcode, initial_witness);
        let too_many_unknowns = || {
            OpcodeResolutionError::OpcodeNotSolvable(
                OpcodeNotSolvable::ExpressionHasTooManyUnknowns(opcode.clone()),
            )
        };

        // Evaluate the multiplication term; more than one of them cannot be solved.
        let mul_result = ExpressionSolver::solve_mul_term(opcode, initial_witness)
            .map_err(|_| too_many_unknowns())?;
        // Evaluate the fan-in terms
        let opcode_status = ExpressionSolver::solve_fan_in_term(opcode, initial_witness);

        match (mul_result, opcode_status) {
            (MulTerm::TooManyUnknowns, _) | (_, OpcodeStatus::OpcodeUnsolvable) => {
                Err(too_many_unknowns())
            }
            // NOTE: `evaluate` has already turned every product with a known operand into a
            // linear term, so the products left here have two unknown operands and the two
            // `MulTerm::OneUnknown` arms below are not expected to be reached from `solve`.
            // They are kept so the match stays exhaustive and correct.
            (MulTerm::OneUnknown(q, w1), OpcodeStatus::OpcodeSolvable(a, (b, w2))) => {
                if w1 == w2 {
                    // One unknown shared by the product and the fan-in:
                    // (q + b) * w1 + a + q_c = 0
                    ExpressionSolver::solve_linear_equation(
                        q + b,
                        a + opcode.q_c,
                        w1,
                        initial_witness,
                    )
                } else {
                    // TODO: can we be more specific with this error?
                    Err(too_many_unknowns())
                }
            }
            (
                MulTerm::OneUnknown(partial_prod, unknown_var),
                OpcodeStatus::OpcodeSatisfied(sum),
            ) => {
                // One unknown in the mul term and the fan-in terms are all known:
                // partial_prod * unknown_var + sum + q_c = 0
                ExpressionSolver::solve_linear_equation(
                    partial_prod,
                    sum + opcode.q_c,
                    unknown_var,
                    initial_witness,
                )
            }
            (MulTerm::Solved(a), OpcodeStatus::OpcodeSatisfied(b)) => {
                // Everything is known: there is nothing to solve, only a check to make.
                ExpressionSolver::check_is_zero(a + b + opcode.q_c)
            }
            (
                MulTerm::Solved(total_prod),
                OpcodeStatus::OpcodeSolvable(partial_sum, (coeff, unknown_var)),
            ) => {
                // The mul term is known and there is one unknown in the fan-in:
                // total_prod + partial_sum + coeff * unknown_var + q_c = 0
                ExpressionSolver::solve_linear_equation(
                    coeff,
                    total_prod + partial_sum + opcode.q_c,
                    unknown_var,
                    initial_witness,
                )
            }
        }
    }

    /// Solves `coefficient * witness + constant = 0` for `witness` and inserts the value
    /// into `initial_witness`.
    ///
    /// When `coefficient` is zero the equation does not depend on `witness`: nothing is
    /// assigned, and the result is that of [`ExpressionSolver::check_is_zero`] on `constant`.
    fn solve_linear_equation<F: AcirField>(
        coefficient: F,
        constant: F,
        witness: Witness,
        initial_witness: &mut WitnessMap<F>,
    ) -> Result<(), OpcodeResolutionError<F>> {
        if coefficient.is_zero() {
            return ExpressionSolver::check_is_zero(constant);
        }
        // witness = -constant / coefficient
        let assignment = -quick_invert(constant, coefficient);
        insert_value(&witness, assignment, initial_witness)
    }

    /// Returns `Ok(())` if `value` (a fully evaluated opcode) is zero, and an
    /// [`OpcodeResolutionError::UnsatisfiedConstrain`] error otherwise.
    fn check_is_zero<F: AcirField>(value: F) -> Result<(), OpcodeResolutionError<F>> {
        if value.is_zero() {
            Ok(())
        } else {
            Err(OpcodeResolutionError::UnsatisfiedConstrain {
                opcode_location: ErrorLocation::Unresolved,
                payload: None,
            })
        }
    }

    /// Tries to reduce the multiplication term of `arith_opcode` to a known value or to a
    /// linear term, using `witness_assignments`.
    ///
    /// - With no multiplication term, returns `MulTerm::Solved(0)`.
    /// - With exactly one, returns the result of [`ExpressionSolver::solve_mul_term_helper`]
    ///   (in particular `MulTerm::TooManyUnknowns` when neither of its witnesses is known).
    /// - With two or more, returns `Err(OpcodeStatus::OpcodeUnsolvable)`.
    fn solve_mul_term<F: AcirField>(
        arith_opcode: &Expression<F>,
        witness_assignments: &WitnessMap<F>,
    ) -> Result<MulTerm<F>, OpcodeStatus<F>> {
        // The mul terms are expected to contain at most one term (i.e. the expression has
        // been optimized); otherwise the opcode is reported as unsolvable.
        match arith_opcode.mul_terms.len() {
            0 => Ok(MulTerm::Solved(F::zero())),
            1 => Ok(ExpressionSolver::solve_mul_term_helper(
                &arith_opcode.mul_terms[0],
                witness_assignments,
            )),
            _ => Err(OpcodeStatus::OpcodeUnsolvable),
        }
    }

    /// Evaluates a single multiplication term `q*a*b`, where `q` is a constant and `a`, `b`
    /// are witnesses:
    /// - both values known: `MulTerm::Solved(q*a*b)`;
    /// - only one known: `MulTerm::OneUnknown(c, w)`, the linear term `c*w` where `c` is `q`
    ///   times the known value and `w` is the unknown witness;
    /// - neither known: `MulTerm::TooManyUnknowns`.
    fn solve_mul_term_helper<F: AcirField>(
        term: &(F, Witness, Witness),
        witness_assignments: &WitnessMap<F>,
    ) -> MulTerm<F> {
        let (q_m, w_l, w_r) = term;
        // Check if these values are in the witness assignments
        let w_l_value = witness_assignments.get(w_l);
        let w_r_value = witness_assignments.get(w_r);

        match (w_l_value, w_r_value) {
            (None, None) => MulTerm::TooManyUnknowns,
            (Some(w_l), Some(w_r)) => MulTerm::Solved(*q_m * *w_l * *w_r),
            (None, Some(w_r)) => MulTerm::OneUnknown(*q_m * *w_r, *w_l),
            (Some(w_l), None) => MulTerm::OneUnknown(*q_m * *w_l, *w_r),
        }
    }

    /// Evaluates the linear term `q*w` if the value of `w` is known in
    /// `witness_assignments`, and returns `None` otherwise.
    fn solve_fan_in_term_helper<F: AcirField>(
        term: &(F, Witness),
        witness_assignments: &WitnessMap<F>,
    ) -> Option<F> {
        let (q_l, w_l) = term;
        witness_assignments.get(w_l).map(|value| *q_l * *value)
    }

    /// Sums the linear terms of `arith_opcode` whose witness value is known and isolates
    /// the unknown witness, if any.
    ///
    /// Several linear terms on the same unknown witness are regrouped into a single term
    /// by adding their coefficients.
    ///
    /// Returns:
    /// - [`OpcodeStatus::OpcodeSatisfied`] with the sum when every witness is known;
    /// - [`OpcodeStatus::OpcodeSolvable`] with the sum and the `(coefficient, witness)` of
    ///   the single unknown witness;
    /// - [`OpcodeStatus::OpcodeUnsolvable`] when there are two or more distinct unknown
    ///   witnesses.
    pub(super) fn solve_fan_in_term<F: AcirField>(
        arith_opcode: &Expression<F>,
        witness_assignments: &WitnessMap<F>,
    ) -> OpcodeStatus<F> {
        // Sum of all the linear terms whose witness value is known.
        let mut known_sum = F::zero();
        // (coefficient, witness) of the only unknown witness seen so far.
        let mut unknown_variable: Option<(F, Witness)> = None;

        for term in &arith_opcode.linear_combinations {
            match ExpressionSolver::solve_fan_in_term_helper(term, witness_assignments) {
                Some(value) => known_sum += value,
                None => {
                    let (coeff, witness) = *term;
                    unknown_variable = match unknown_variable {
                        None => Some((coeff, witness)),
                        // The same unknown can appear in several terms, e.g. when `evaluate`
                        // linearises `q*a*b` (with `b` known) next to an existing `c*a` term.
                        Some((acc, unknown)) if unknown == witness => {
                            Some((acc + coeff, unknown))
                        }
                        // A second, distinct unknown: this equation cannot be solved.
                        Some(_) => return OpcodeStatus::OpcodeUnsolvable,
                    };
                }
            }
        }

        match unknown_variable {
            None => OpcodeStatus::OpcodeSatisfied(known_sum),
            Some(unknown) => OpcodeStatus::OpcodeSolvable(known_sum, unknown),
        }
    }

    /// Partially evaluates `expr` using the witness values known in `initial_witness`.
    ///
    /// - A product `q*a*b` whose witnesses are both known is folded into the constant term.
    /// - If only `a` is known, `q*a*b` is replaced by the linear term `(q * value of a) * b`
    ///   (and symmetrically when only `b` is known).
    /// - A linear term whose witness is known is folded into the constant term.
    /// - Terms that still involve unknown witnesses are kept, except those whose (resulting)
    ///   coefficient is zero, which are dropped.
    ///
    /// If every value is known, the result is a constant expression. Terms on the same
    /// unknown witness are not merged, so the result may hold several linear terms on the
    /// same witness.
    pub(crate) fn evaluate<F: AcirField>(
        expr: &Expression<F>,
        initial_witness: &WitnessMap<F>,
    ) -> Expression<F> {
        let mut result = Expression::default();
        for &(c, w1, w2) in &expr.mul_terms {
            let mul_result = ExpressionSolver::solve_mul_term_helper(&(c, w1, w2), initial_witness);
            match mul_result {
                MulTerm::OneUnknown(v, w) => {
                    if !v.is_zero() {
                        result.linear_combinations.push((v, w));
                    }
                }
                MulTerm::TooManyUnknowns => {
                    if !c.is_zero() {
                        result.mul_terms.push((c, w1, w2));
                    }
                }
                MulTerm::Solved(f) => result.q_c += f,
            }
        }
        for &(c, w) in &expr.linear_combinations {
            if let Some(f) = ExpressionSolver::solve_fan_in_term_helper(&(c, w), initial_witness) {
                result.q_c += f;
            } else if !c.is_zero() {
                result.linear_combinations.push((c, w));
            }
        }
        result.q_c += expr.q_c;
        result
    }
}
/// Computes `numerator / denominator`, short-circuiting the denominators `1` and `-1`
/// so that no field inversion is performed for them.
///
/// Field inversion dominates the cost of solving
/// [`Opcode::AssertZero`][acir::circuit::opcodes::Opcode::AssertZero] opcodes, and these
/// two denominators can be handled without it.
fn quick_invert<F: AcirField>(numerator: F, denominator: F) -> F {
    let one = F::one();
    match denominator {
        d if d == one => numerator,
        d if d == -one => -numerator,
        d => numerator / d,
    }
}
#[cfg(test)]
mod tests {
    use super::*;
    use acir::FieldElement;

    #[test]
    /// Sanity check for the special cases of [`quick_invert`] and its general fallback.
    fn quick_invert_matches_slow_invert() {
        let numerator = FieldElement::from_be_bytes_reduce("hello_world".as_bytes());
        assert_eq!(quick_invert(numerator, FieldElement::one()), numerator / FieldElement::one());
        assert_eq!(quick_invert(numerator, -FieldElement::one()), numerator / -FieldElement::one());
        // Any other denominator goes through a real field division.
        let two = FieldElement::from(2_i128);
        assert_eq!(quick_invert(numerator, two) * two, numerator);
    }

    #[test]
    fn solves_simple_assignment() {
        let a = Witness(0);
        // a - 1 == 0;
        let opcode_a = Expression {
            mul_terms: vec![],
            linear_combinations: vec![(FieldElement::one(), a)],
            q_c: -FieldElement::one(),
        };
        let mut values = WitnessMap::new();
        assert_eq!(ExpressionSolver::solve(&mut values, &opcode_a), Ok(()));
        assert_eq!(values.get(&a).unwrap(), &FieldElement::from(1_i128));
    }

    #[test]
    fn solves_unknown_in_mul_term() {
        let a = Witness(0);
        let b = Witness(1);
        let c = Witness(2);
        let d = Witness(3);
        // a * b - b - c - d == 0;
        let opcode_a = Expression {
            mul_terms: vec![(FieldElement::one(), a, b)],
            linear_combinations: vec![
                (-FieldElement::one(), b),
                (-FieldElement::one(), c),
                (-FieldElement::one(), d),
            ],
            q_c: FieldElement::zero(),
        };
        let mut values = WitnessMap::new();
        values.insert(b, FieldElement::from(2_i128));
        values.insert(c, FieldElement::from(1_i128));
        values.insert(d, FieldElement::from(1_i128));
        assert_eq!(ExpressionSolver::solve(&mut values, &opcode_a), Ok(()));
        assert_eq!(values.get(&a).unwrap(), &FieldElement::from(2_i128));
    }

    #[test]
    fn solves_unknown_in_linear_term() {
        let a = Witness(0);
        let b = Witness(1);
        let c = Witness(2);
        let d = Witness(3);
        // a = b + c + d;
        let opcode_a = Expression {
            mul_terms: vec![],
            linear_combinations: vec![
                (FieldElement::one(), a),
                (-FieldElement::one(), b),
                (-FieldElement::one(), c),
                (-FieldElement::one(), d),
            ],
            q_c: FieldElement::zero(),
        };
        let e = Witness(4);
        // e = a + b;
        let opcode_b = Expression {
            mul_terms: vec![],
            linear_combinations: vec![
                (FieldElement::one(), e),
                (-FieldElement::one(), a),
                (-FieldElement::one(), b),
            ],
            q_c: FieldElement::zero(),
        };
        let mut values = WitnessMap::new();
        values.insert(b, FieldElement::from(2_i128));
        values.insert(c, FieldElement::from(1_i128));
        values.insert(d, FieldElement::from(1_i128));
        assert_eq!(ExpressionSolver::solve(&mut values, &opcode_a), Ok(()));
        assert_eq!(ExpressionSolver::solve(&mut values, &opcode_b), Ok(()));
        assert_eq!(values.get(&a).unwrap(), &FieldElement::from(4_i128));
        // The second opcode must have been solved using the freshly derived `a`.
        assert_eq!(values.get(&e).unwrap(), &FieldElement::from(6_i128));
    }

    #[test]
    fn rejects_unsatisfied_constraint() {
        let a = Witness(0);
        // a - 1 == 0, but a is already assigned 2
        let opcode_a = Expression {
            mul_terms: vec![],
            linear_combinations: vec![(FieldElement::one(), a)],
            q_c: -FieldElement::one(),
        };
        let mut values = WitnessMap::new();
        values.insert(a, FieldElement::from(2_i128));
        assert_eq!(
            ExpressionSolver::solve(&mut values, &opcode_a),
            Err(OpcodeResolutionError::UnsatisfiedConstrain {
                opcode_location: ErrorLocation::Unresolved,
                payload: None,
            })
        );
    }

    #[test]
    fn reports_too_many_unknowns() {
        let a = Witness(0);
        let b = Witness(1);
        // a + b == 0, with neither a nor b known
        let opcode = Expression {
            mul_terms: vec![],
            linear_combinations: vec![(FieldElement::one(), a), (FieldElement::one(), b)],
            q_c: FieldElement::zero(),
        };
        let mut values = WitnessMap::new();
        assert_eq!(
            ExpressionSolver::solve(&mut values, &opcode),
            Err(OpcodeResolutionError::OpcodeNotSolvable(
                OpcodeNotSolvable::ExpressionHasTooManyUnknowns(opcode.clone()),
            ))
        );
        assert!(values.get(&a).is_none());
        assert!(values.get(&b).is_none());
    }
}