acvm/compiler/optimizers/
general.rs

1use acir::{
2    AcirField,
3    native_types::{Expression, Witness},
4};
5use indexmap::IndexMap;
6
7/// The `GeneralOptimizer` processes all [`Expression`]s to:
8/// - remove any zero-coefficient terms.
9/// - merge any quadratic terms containing the same two witnesses.
10///
11/// This pass does not depend on any other pass and should be the first one in a set of optimizing passes.
12pub(crate) struct GeneralOptimizer;
13
14impl GeneralOptimizer {
15    pub(crate) fn optimize<F: AcirField>(opcode: Expression<F>) -> Expression<F> {
16        // TODO(https://github.com/noir-lang/noir/issues/10109): Perhaps this optimization can be done on the fly
17        let opcode = simplify_mul_terms(opcode);
18        simplify_linear_terms(opcode)
19    }
20}
21
22/// Simplifies all mul terms of the form `scale*w1*w2` with the same bi-variate variables
23/// while also removing terms that end up with a zero coefficient.
24///
25/// For instance, mul terms `0*w1*w1 + 2*w2*w1 - w2*w1 - w1*w2` will return an
26/// empty vector, because: w1*w2 and w2*w1 are the same bi-variate variable
27/// and the resulting scale is `2-1-1 = 0`
28fn simplify_mul_terms<F: AcirField>(mut gate: Expression<F>) -> Expression<F> {
29    let mut hash_map: IndexMap<(Witness, Witness), F> = IndexMap::new();
30
31    // Canonicalize the ordering of the multiplication, lets just order by variable name
32    for (scale, w_l, w_r) in gate.mul_terms {
33        let mut pair = [w_l, w_r];
34        // Sort using rust sort algorithm
35        pair.sort();
36
37        *hash_map.entry((pair[0], pair[1])).or_insert_with(F::zero) += scale;
38    }
39
40    gate.mul_terms = hash_map
41        .into_iter()
42        .filter(|(_, scale)| !scale.is_zero())
43        .map(|((w_l, w_r), scale)| (scale, w_l, w_r))
44        .collect();
45    gate
46}
47
48// Simplifies all linear terms with the same variables while also removing
49// terms that end up with a zero coefficient.
50fn simplify_linear_terms<F: AcirField>(mut gate: Expression<F>) -> Expression<F> {
51    let mut hash_map: IndexMap<Witness, F> = IndexMap::new();
52
53    // Canonicalize the ordering of the terms, let's just order by variable name
54    for (scale, witness) in gate.linear_combinations {
55        *hash_map.entry(witness).or_insert_with(F::zero) += scale;
56    }
57
58    gate.linear_combinations = hash_map
59        .into_iter()
60        .filter(|(_, scale)| !scale.is_zero())
61        .map(|(witness, scale)| (scale, witness))
62        .collect();
63    gate
64}
65
#[cfg(test)]
mod tests {
    use acir::{
        FieldElement,
        circuit::{Circuit, Opcode},
    };

    use crate::{assert_circuit_snapshot, compiler::optimizers::GeneralOptimizer};

    /// Runs [`GeneralOptimizer::optimize`] over every `AssertZero` opcode of `circuit`,
    /// leaving all other opcodes untouched.
    fn optimize(mut circuit: Circuit<FieldElement>) -> Circuit<FieldElement> {
        // Move the opcodes out of the circuit instead of cloning the whole circuit.
        circuit.opcodes = std::mem::take(&mut circuit.opcodes)
            .into_iter()
            .map(|opcode| match opcode {
                Opcode::AssertZero(arith_expr) => {
                    Opcode::AssertZero(GeneralOptimizer::optimize(arith_expr))
                }
                other => other,
            })
            .collect();
        circuit
    }

    #[test]
    fn removes_zero_coefficients_from_mul_terms() {
        let src = "
        private parameters: [w0, w1]
        public parameters: []
        return values: []

        // The first multiplication should be removed
        ASSERT 0*w0*w1 + w0*w1 = 0
        ";
        let circuit = Circuit::from_str(src).unwrap();
        let optimized_circuit = optimize(circuit);
        assert_circuit_snapshot!(optimized_circuit, @r"
        private parameters: [w0, w1]
        public parameters: []
        return values: []
        ASSERT 0 = w0*w1
        ");
    }

    #[test]
    fn removes_zero_coefficients_from_linear_terms() {
        let src = "
        private parameters: [w0, w1]
        public parameters: []
        return values: []

        // The first linear combination should be removed
        ASSERT 0*w0 + w1 = 0
        ";
        let circuit = Circuit::from_str(src).unwrap();
        let optimized_circuit = optimize(circuit);
        assert_circuit_snapshot!(optimized_circuit, @r"
        private parameters: [w0, w1]
        public parameters: []
        return values: []
        ASSERT w1 = 0
        ");
    }

    #[test]
    fn simplifies_mul_terms() {
        let src = "
        private parameters: [w0, w1]
        public parameters: []
        return values: []

        // These are all mul terms with the same variables so we should end up with just one
        // that is the sum of all the coefficients
        ASSERT 2*w0*w1 + 3*w1*w0 + 4*w0*w1 = 0
        ";
        let circuit = Circuit::from_str(src).unwrap();
        let optimized_circuit = optimize(circuit);
        assert_circuit_snapshot!(optimized_circuit, @r"
        private parameters: [w0, w1]
        public parameters: []
        return values: []
        ASSERT 0 = 9*w0*w1
        ");
    }

    #[test]
    fn removes_zero_coefficients_after_simplifying_mul_terms() {
        let src = "
        private parameters: [w0, w1]
        public parameters: []
        return values: []
        ASSERT 2*w0*w1 + 3*w1*w0 - 5*w0*w1 = 0
        ";
        let circuit = Circuit::from_str(src).unwrap();
        let optimized_circuit = optimize(circuit);
        assert_circuit_snapshot!(optimized_circuit, @r"
        private parameters: [w0, w1]
        public parameters: []
        return values: []
        ASSERT 0 = 0
        ");
    }

    #[test]
    fn simplifies_linear_terms() {
        let src = "
        private parameters: [w0, w1]
        public parameters: []
        return values: []

        // These are all linear terms with the same variable so we should end up with just one
        // that is the sum of all the coefficients
        ASSERT w0 + 2*w0 + 3*w0 = 0
        ";
        let circuit = Circuit::from_str(src).unwrap();
        let optimized_circuit = optimize(circuit);
        assert_circuit_snapshot!(optimized_circuit, @r"
        private parameters: [w0, w1]
        public parameters: []
        return values: []
        ASSERT 0 = 6*w0
        ");
    }

    #[test]
    fn removes_zero_coefficients_after_simplifying_linear_terms() {
        let src = "
        private parameters: [w0, w1]
        public parameters: []
        return values: []
        ASSERT w0 + 2*w0 - 3*w0 = 0
        ";
        let circuit = Circuit::from_str(src).unwrap();
        let optimized_circuit = optimize(circuit);
        assert_circuit_snapshot!(optimized_circuit, @r"
        private parameters: [w0, w1]
        public parameters: []
        return values: []
        ASSERT 0 = 0
        ");
    }

    #[test]
    fn simplify_mul_terms_example() {
        let src = "
        private parameters: [w1, w2]
        public parameters: []
        return values: []
        ASSERT 0*w1*w1 + 2*w2*w1 - w2*w1 - w1*w2 = 0
        ";
        let circuit = Circuit::from_str(src).unwrap();
        let optimized_circuit = optimize(circuit);
        assert_circuit_snapshot!(optimized_circuit, @r"
        private parameters: [w1, w2]
        public parameters: []
        return values: []
        ASSERT 0 = 0
        ");
    }
}