acvm/compiler/optimizers/
general.rs

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
use acir::{
    AcirField,
    native_types::{Expression, Witness},
};
use indexmap::IndexMap;

/// The `GeneralOptimizer` normalizes a single [`Expression`] by:
/// - merging quadratic terms that multiply the same two witnesses (in either order),
/// - merging linear terms that refer to the same witness,
/// - dropping every term whose resulting coefficient is zero.
///
/// This pass does not depend on any other pass and should be the first one in a set of optimizing passes.
pub(crate) struct GeneralOptimizer;

impl GeneralOptimizer {
    /// Simplifies the quadratic terms of `opcode` first, then its linear terms, and returns
    /// the resulting expression. The constant term is not touched by either step.
    pub(crate) fn optimize<F: AcirField>(opcode: Expression<F>) -> Expression<F> {
        // XXX: Perhaps this optimization can be done on the fly
        let merged_products = simplify_mul_terms(opcode);
        simplify_linear_terms(merged_products)
    }
}

/// Merges every quadratic term `scale*w1*w2` that multiplies the same pair of witnesses,
/// then drops any term whose accumulated coefficient is zero.
///
/// Each pair is normalized using [`Witness`]'s ordering so that `w1*w2` and `w2*w1` share a
/// single entry. The surviving terms keep the order in which each pair was first seen.
///
/// For instance, mul terms `0*w1*w1 + 2*w2*w1 - w2*w1 - w1*w2` collapse to an empty list:
/// `0*w1*w1` already has a zero coefficient, and `w1*w2` / `w2*w1` are the same product whose
/// coefficients sum to `2 - 1 - 1 = 0`.
fn simplify_mul_terms<F: AcirField>(mut gate: Expression<F>) -> Expression<F> {
    let mut merged: IndexMap<(Witness, Witness), F> = IndexMap::new();

    for (coefficient, lhs, rhs) in gate.mul_terms {
        // Put the smaller witness first so both orientations of a product map to one key.
        let key = if lhs <= rhs { (lhs, rhs) } else { (rhs, lhs) };
        *merged.entry(key).or_insert_with(F::zero) += coefficient;
    }

    gate.mul_terms = merged
        .into_iter()
        .filter_map(|((lhs, rhs), coefficient)| {
            (!coefficient.is_zero()).then_some((coefficient, lhs, rhs))
        })
        .collect();
    gate
}

/// Merges every linear term `scale*w` that refers to the same witness, then drops any term
/// whose accumulated coefficient is zero.
///
/// No sorting takes place. An [`IndexMap`] preserves insertion order, so the surviving terms
/// keep the order in which each witness was first seen and the output is deterministic.
///
/// For instance, linear terms `w0 + 2*w0 - 3*w0 + 0*w1` collapse to an empty list.
fn simplify_linear_terms<F: AcirField>(mut gate: Expression<F>) -> Expression<F> {
    let mut by_witness: IndexMap<Witness, F> = IndexMap::new();

    // Accumulate the coefficients of each witness; first occurrence fixes its position.
    for (scale, witness) in gate.linear_combinations.into_iter() {
        *by_witness.entry(witness).or_insert_with(F::zero) += scale;
    }

    gate.linear_combinations = by_witness
        .into_iter()
        .filter(|(_, scale)| !scale.is_zero())
        .map(|(witness, scale)| (scale, witness))
        .collect();
    gate
}

#[cfg(test)]
mod tests {
    use acir::{
        FieldElement,
        circuit::{Circuit, Opcode},
    };

    use crate::{assert_circuit_snapshot, compiler::optimizers::GeneralOptimizer};

    /// Applies [`GeneralOptimizer::optimize`] to every `AssertZero` opcode of `circuit`,
    /// leaving all other opcodes (and the rest of the circuit) untouched.
    fn optimize(mut circuit: Circuit<FieldElement>) -> Circuit<FieldElement> {
        // Move the opcodes out of the circuit rather than cloning the whole circuit.
        circuit.opcodes = std::mem::take(&mut circuit.opcodes)
            .into_iter()
            .map(|opcode| match opcode {
                Opcode::AssertZero(expr) => Opcode::AssertZero(GeneralOptimizer::optimize(expr)),
                other => other,
            })
            .collect();
        circuit
    }

    #[test]
    fn removes_zero_coefficients_from_mul_terms() {
        let src = "
        private parameters: [w0, w1]
        public parameters: []
        return values: []

        // The first multiplication should be removed
        ASSERT 0*w0*w1 + w0*w1 = 0
        ";
        let circuit = Circuit::from_str(src).unwrap();
        let optimized_circuit = optimize(circuit);
        assert_circuit_snapshot!(optimized_circuit, @r"
        private parameters: [w0, w1]
        public parameters: []
        return values: []
        ASSERT 0 = w0*w1
        ");
    }

    #[test]
    fn removes_zero_coefficients_from_linear_terms() {
        let src = "
        private parameters: [w0, w1]
        public parameters: []
        return values: []

        // The first linear combination should be removed
        ASSERT 0*w0 + w1 = 0
        ";
        let circuit = Circuit::from_str(src).unwrap();
        let optimized_circuit = optimize(circuit);
        assert_circuit_snapshot!(optimized_circuit, @r"
        private parameters: [w0, w1]
        public parameters: []
        return values: []
        ASSERT w1 = 0
        ");
    }

    #[test]
    fn simplifies_mul_terms() {
        let src = "
        private parameters: [w0, w1]
        public parameters: []
        return values: []

        // There are all mul terms with the same variables so we should end up with just one
        // that is the sum of all the coefficients
        ASSERT 2*w0*w1 + 3*w1*w0 + 4*w0*w1 = 0
        ";
        let circuit = Circuit::from_str(src).unwrap();
        let optimized_circuit = optimize(circuit);
        assert_circuit_snapshot!(optimized_circuit, @r"
        private parameters: [w0, w1]
        public parameters: []
        return values: []
        ASSERT 0 = 9*w0*w1
        ");
    }

    #[test]
    fn removes_zero_coefficients_after_simplifying_mul_terms() {
        let src = "
        private parameters: [w0, w1]
        public parameters: []
        return values: []
        ASSERT 2*w0*w1 + 3*w1*w0 - 5*w0*w1 = 0
        ";
        let circuit = Circuit::from_str(src).unwrap();
        let optimized_circuit = optimize(circuit);
        assert_circuit_snapshot!(optimized_circuit, @r"
        private parameters: [w0, w1]
        public parameters: []
        return values: []
        ASSERT 0 = 0
        ");
    }

    #[test]
    fn simplifies_linear_terms() {
        let src = "
        private parameters: [w0, w1]
        public parameters: []
        return values: []

        // These are all linear terms with the same variable so we should end up with just one
        // that is the sum of all the coefficients
        ASSERT w0 + 2*w0 + 3*w0 = 0
        ";
        let circuit = Circuit::from_str(src).unwrap();
        let optimized_circuit = optimize(circuit);
        assert_circuit_snapshot!(optimized_circuit, @r"
        private parameters: [w0, w1]
        public parameters: []
        return values: []
        ASSERT 0 = 6*w0
        ");
    }

    #[test]
    fn removes_zero_coefficients_after_simplifying_linear_terms() {
        let src = "
        private parameters: [w0, w1]
        public parameters: []
        return values: []
        ASSERT w0 + 2*w0 - 3*w0 = 0
        ";
        let circuit = Circuit::from_str(src).unwrap();
        let optimized_circuit = optimize(circuit);
        assert_circuit_snapshot!(optimized_circuit, @r"
        private parameters: [w0, w1]
        public parameters: []
        return values: []
        ASSERT 0 = 0
        ");
    }
}