acvm/compiler/optimizers/
general.rs1use acir::{
2 AcirField,
3 native_types::{Expression, Witness},
4};
5use indexmap::IndexMap;
6
7pub(crate) struct GeneralOptimizer;
13
14impl GeneralOptimizer {
15 pub(crate) fn optimize<F: AcirField>(opcode: Expression<F>) -> Expression<F> {
16 let opcode = simplify_mul_terms(opcode);
18 simplify_linear_terms(opcode)
19 }
20}
21
22fn simplify_mul_terms<F: AcirField>(mut gate: Expression<F>) -> Expression<F> {
29 let mut hash_map: IndexMap<(Witness, Witness), F> = IndexMap::new();
30
31 for (scale, w_l, w_r) in gate.mul_terms {
33 let mut pair = [w_l, w_r];
34 pair.sort();
36
37 *hash_map.entry((pair[0], pair[1])).or_insert_with(F::zero) += scale;
38 }
39
40 gate.mul_terms = hash_map
41 .into_iter()
42 .filter(|(_, scale)| !scale.is_zero())
43 .map(|((w_l, w_r), scale)| (scale, w_l, w_r))
44 .collect();
45 gate
46}
47
48fn simplify_linear_terms<F: AcirField>(mut gate: Expression<F>) -> Expression<F> {
51 let mut hash_map: IndexMap<Witness, F> = IndexMap::new();
52
53 for (scale, witness) in gate.linear_combinations {
55 *hash_map.entry(witness).or_insert_with(F::zero) += scale;
56 }
57
58 gate.linear_combinations = hash_map
59 .into_iter()
60 .filter(|(_, scale)| !scale.is_zero())
61 .map(|(witness, scale)| (scale, witness))
62 .collect();
63 gate
64}
65
#[cfg(test)]
mod tests {
    use acir::{
        FieldElement,
        circuit::{Circuit, Opcode},
    };

    use crate::{assert_circuit_snapshot, compiler::optimizers::GeneralOptimizer};

    /// Runs [`GeneralOptimizer::optimize`] on every `AssertZero` opcode of `circuit`,
    /// leaving all other opcodes and circuit fields untouched.
    fn optimize(mut circuit: Circuit<FieldElement>) -> Circuit<FieldElement> {
        // Move the opcode list out instead of cloning the whole circuit: only the opcodes
        // are rewritten, so there is no need to copy anything else.
        circuit.opcodes = std::mem::take(&mut circuit.opcodes)
            .into_iter()
            .map(|opcode| match opcode {
                Opcode::AssertZero(expr) => Opcode::AssertZero(GeneralOptimizer::optimize(expr)),
                other => other,
            })
            .collect();
        circuit
    }

    #[test]
    fn removes_zero_coefficients_from_mul_terms() {
        let src = "
        private parameters: [w0, w1]
        public parameters: []
        return values: []

        // The first multiplication should be removed
        ASSERT 0*w0*w1 + w0*w1 = 0
        ";
        let circuit = Circuit::from_str(src).unwrap();
        let optimized_circuit = optimize(circuit);
        assert_circuit_snapshot!(optimized_circuit, @r"
        private parameters: [w0, w1]
        public parameters: []
        return values: []
        ASSERT 0 = w0*w1
        ");
    }

    #[test]
    fn removes_zero_coefficients_from_linear_terms() {
        let src = "
        private parameters: [w0, w1]
        public parameters: []
        return values: []

        // The first linear combination should be removed
        ASSERT 0*w0 + w1 = 0
        ";
        let circuit = Circuit::from_str(src).unwrap();
        let optimized_circuit = optimize(circuit);
        assert_circuit_snapshot!(optimized_circuit, @r"
        private parameters: [w0, w1]
        public parameters: []
        return values: []
        ASSERT w1 = 0
        ");
    }

    #[test]
    fn simplifies_mul_terms() {
        let src = "
        private parameters: [w0, w1]
        public parameters: []
        return values: []

        // There are all mul terms with the same variables so we should end up with just one
        // that is the sum of all the coefficients
        ASSERT 2*w0*w1 + 3*w1*w0 + 4*w0*w1 = 0
        ";
        let circuit = Circuit::from_str(src).unwrap();
        let optimized_circuit = optimize(circuit);
        assert_circuit_snapshot!(optimized_circuit, @r"
        private parameters: [w0, w1]
        public parameters: []
        return values: []
        ASSERT 0 = 9*w0*w1
        ");
    }

    #[test]
    fn removes_zero_coefficients_after_simplifying_mul_terms() {
        // 2 + 3 - 5 = 0, so the merged product term must disappear entirely.
        let src = "
        private parameters: [w0, w1]
        public parameters: []
        return values: []
        ASSERT 2*w0*w1 + 3*w1*w0 - 5*w0*w1 = 0
        ";
        let circuit = Circuit::from_str(src).unwrap();
        let optimized_circuit = optimize(circuit);
        assert_circuit_snapshot!(optimized_circuit, @r"
        private parameters: [w0, w1]
        public parameters: []
        return values: []
        ASSERT 0 = 0
        ");
    }

    #[test]
    fn simplifies_linear_terms() {
        let src = "
        private parameters: [w0, w1]
        public parameters: []
        return values: []

        // These are all linear terms with the same variable so we should end up with just one
        // that is the sum of all the coefficients
        ASSERT w0 + 2*w0 + 3*w0 = 0
        ";
        let circuit = Circuit::from_str(src).unwrap();
        let optimized_circuit = optimize(circuit);
        assert_circuit_snapshot!(optimized_circuit, @r"
        private parameters: [w0, w1]
        public parameters: []
        return values: []
        ASSERT 0 = 6*w0
        ");
    }

    #[test]
    fn removes_zero_coefficients_after_simplifying_linear_terms() {
        // 1 + 2 - 3 = 0, so the merged linear term must disappear entirely.
        let src = "
        private parameters: [w0, w1]
        public parameters: []
        return values: []
        ASSERT w0 + 2*w0 - 3*w0 = 0
        ";
        let circuit = Circuit::from_str(src).unwrap();
        let optimized_circuit = optimize(circuit);
        assert_circuit_snapshot!(optimized_circuit, @r"
        private parameters: [w0, w1]
        public parameters: []
        return values: []
        ASSERT 0 = 0
        ");
    }

    #[test]
    fn simplify_mul_terms_example() {
        // `w2*w1` and `w1*w2` are the same product, so their coefficients (2 - 1 - 1) cancel;
        // the `0*w1*w1` term is dropped as well, leaving no terms at all.
        // `w2` is used in the expression, so it is declared as a parameter too.
        let src = "
        private parameters: [w0, w1, w2]
        public parameters: []
        return values: []
        ASSERT 0*w1*w1 + 2*w2*w1 - w2*w1 - w1*w2 = 0
        ";
        let circuit = Circuit::from_str(src).unwrap();
        let optimized_circuit = optimize(circuit);
        assert_circuit_snapshot!(optimized_circuit, @r"
        private parameters: [w0, w1, w2]
        public parameters: []
        return values: []
        ASSERT 0 = 0
        ");
    }
}