// acvm/pwg/memory_op.rs
use acir::{
    AcirField,
    circuit::opcodes::MemOp,
    native_types::{Witness, WitnessMap},
};

use super::{ErrorLocation, OpcodeResolutionError};
use super::{arithmetic::ExpressionSolver, get_value, insert_value, witness_to_value};

/// Position of an element inside a memory block.
///
/// Memory indices are limited to 32 bits.
type MemoryIndex = u32;

/// Maintains the state for solving [`MemoryInit`][`acir::circuit::Opcode::MemoryInit`] and [`MemoryOp`][`acir::circuit::Opcode::MemoryOp`] opcodes.
pub(crate) struct MemoryOpSolver<F> {
    /// Current value of every element of the memory block, indexed by `MemoryIndex`.
    ///
    /// The vector is sized once, when the solver is created (one entry per `init` witness).
    /// Processing memory operations then overwrites entries in place. Writes past the current
    /// length are rejected rather than growing the vector.
    pub(super) block_value: Vec<F>,
}

impl<F: AcirField> MemoryOpSolver<F> {
    /// Creates a new `MemoryOpSolver` whose block holds the values of the `init` witnesses,
    /// in order, as read from `initial_witness`.
    ///
    /// # Errors
    ///
    /// Propagates the first error returned by `witness_to_value`. Presumably this happens when
    /// one of the `init` witnesses has no value in `initial_witness` yet.
    pub(crate) fn new(
        init: &[Witness],
        initial_witness: &WitnessMap<F>,
    ) -> Result<Self, OpcodeResolutionError<F>> {
        let block_value = init
            .iter()
            .map(|witness| witness_to_value(initial_witness, *witness).copied())
            .collect::<Result<Vec<_>, _>>()?;
        Ok(Self { block_value })
    }

    /// Returns the number of elements in the memory block.
    ///
    /// # Panics
    ///
    /// Panics if the block holds more than `u32::MAX` elements.
    fn len(&self) -> u32 {
        u32::try_from(self.block_value.len()).expect("expected a length that fits into a u32")
    }

    /// Builds the `IndexOutOfBounds` error reported when `index` cannot be used to access
    /// this block.
    fn index_out_of_bounds(&self, index: F) -> OpcodeResolutionError<F> {
        OpcodeResolutionError::IndexOutOfBounds {
            opcode_location: ErrorLocation::Unresolved,
            index,
            array_size: self.len(),
        }
    }

    /// Converts a field element into a memory index.
    ///
    /// Only values that fit into 32 bits are valid memory indices. Larger values are reported
    /// as an `IndexOutOfBounds` error carrying the original field element.
    fn index_from_field(&self, index: F) -> Result<MemoryIndex, OpcodeResolutionError<F>> {
        index.try_to_u32().ok_or_else(|| self.index_out_of_bounds(index))
    }

    /// Stores `value` at `index` in the memory block.
    ///
    /// Returns an `IndexOutOfBounds` error if `index` is not smaller than the block length.
    /// The block is left untouched in that case.
    fn write_memory_index(
        &mut self,
        index: MemoryIndex,
        value: F,
    ) -> Result<(), OpcodeResolutionError<F>> {
        match self.block_value.get_mut(index as usize) {
            Some(slot) => {
                *slot = value;
                Ok(())
            }
            None => Err(self.index_out_of_bounds(F::from(u128::from(index)))),
        }
    }

    /// Returns the value stored at `index` in the memory block.
    ///
    /// Returns an `IndexOutOfBounds` error if `index` is not smaller than the block length.
    fn read_memory_index(&self, index: MemoryIndex) -> Result<F, OpcodeResolutionError<F>> {
        // `ok_or_else` builds the error (and its field conversion) only on failure,
        // not on every successful read.
        self.block_value
            .get(index as usize)
            .copied()
            .ok_or_else(|| self.index_out_of_bounds(F::from(u128::from(index))))
    }

    /// Updates `block_value` by processing one memory operation.
    ///
    /// The opcode `op` holds the operation kind, the index and the value, all stored as
    /// [acir::native_types::Expression]s:
    /// - `op.operation` must evaluate to a known value: `0` means read, and any non-zero
    ///   value is treated as a write. With `pedantic_solving`, values other than `0` or `1`
    ///   are rejected with `MemoryOperationLargerThanOne`.
    /// - `op.index` need not be constant, but it must evaluate to a known value using
    ///   `initial_witness`. It must also fit into 32 bits and be within the block.
    ///
    /// READ: reads the block at `op.index` and assigns the result to the witness that
    /// `op.value` reduces to (after partial evaluation), via `insert_value` on
    /// `initial_witness`.
    ///
    /// WRITE: evaluates `op.value` to a known value and stores it in the block at `op.index`.
    ///
    /// # Errors
    ///
    /// Returns an error when the operation, index or written value cannot be evaluated.
    /// Also returns an error when the index is out of bounds, when the operation is not 0/1
    /// under pedantic solving, or when `insert_value` fails.
    ///
    /// # Panics
    ///
    /// Panics if a read's `op.value` does not reduce to a single witness.
    pub(crate) fn solve_memory_op(
        &mut self,
        op: &MemOp<F>,
        initial_witness: &mut WitnessMap<F>,
        pedantic_solving: bool,
    ) -> Result<(), OpcodeResolutionError<F>> {
        let operation = get_value(&op.operation, initial_witness)?;

        // Find the memory index associated with this memory operation.
        let index = get_value(&op.index, initial_witness)?;
        let memory_index = self.index_from_field(index)?;

        // Calculate the value associated with this memory operation.
        //
        // In read operations, this corresponds to the witness index at which the value from memory will be written.
        // In write operations, this corresponds to the expression which will be written to memory.
        let value = ExpressionSolver::evaluate(&op.value, initial_witness);

        // `operation == 0` implies a read operation. (`operation == 1` implies write operation).
        let is_read_operation = operation.is_zero();

        // Under pedantic solving, the operation must be exactly 0 or 1.
        if pedantic_solving && !is_read_operation && !operation.is_one() {
            return Err(OpcodeResolutionError::MemoryOperationLargerThanOne {
                opcode_location: ErrorLocation::Unresolved,
                operation,
            });
        }

        if is_read_operation {
            // `value_read = arr[memory_index]`
            //
            // Copy from the memory block into the witness that `op.value` reduces to.
            let value_read_witness = value.to_witness().expect(
                "Memory must be read into a specified witness index, encountered an Expression",
            );

            let value_in_array = self.read_memory_index(memory_index)?;
            insert_value(&value_read_witness, value_in_array, initial_witness)
        } else {
            // `arr[memory_index] = value_write`
            //
            // The written expression must be fully known before it can be stored in the block.
            let value_to_write = get_value(&value, initial_witness)?;
            self.write_memory_index(memory_index, value_to_write)
        }
    }
}

#[cfg(test)]
mod tests {
    use std::collections::BTreeMap;

    use acir::{
        AcirField, FieldElement,
        circuit::opcodes::MemOp,
        native_types::{Witness, WitnessMap},
    };

    use super::MemoryOpSolver;

    /// Witness map shared by the tests. Witnesses 1 and 2 hold `1` and form the initial
    /// two-element block. Witness 3 holds `2` and is used as a value to write.
    fn witness_map() -> WitnessMap<FieldElement> {
        WitnessMap::from(BTreeMap::from_iter([
            (Witness(1), FieldElement::from(1u128)),
            (Witness(2), FieldElement::from(1u128)),
            (Witness(3), FieldElement::from(2u128)),
        ]))
    }

    #[test]
    fn test_solver() {
        let mut initial_witness = witness_map();

        let init = vec![Witness(1), Witness(2)];
        // Write the value '2' at index '1', and then read from index '1' into witness 4
        let trace = vec![
            MemOp::write_to_mem_index(FieldElement::from(1u128).into(), Witness(3).into()),
            MemOp::read_at_mem_index(FieldElement::one().into(), Witness(4)),
        ];

        let mut block_solver = MemoryOpSolver::new(&init, &initial_witness).unwrap();

        for op in trace {
            let pedantic_solving = true;
            block_solver.solve_memory_op(&op, &mut initial_witness, pedantic_solving).unwrap();
        }

        assert_eq!(initial_witness[&Witness(4)], FieldElement::from(2u128));
    }

    #[test]
    fn test_index_out_of_bounds() {
        let mut initial_witness = witness_map();

        let init = vec![Witness(1), Witness(2)];
        // Write at index '1', and then read at index '2' on an array of size 2.
        let invalid_trace = [
            MemOp::write_to_mem_index(FieldElement::from(1u128).into(), Witness(3).into()),
            MemOp::read_at_mem_index(FieldElement::from(2u128).into(), Witness(4)),
        ];
        let mut block_solver = MemoryOpSolver::new(&init, &initial_witness).unwrap();

        // Run the trace until the first failing operation.
        let pedantic_solving = true;
        let err = invalid_trace.iter().find_map(|op| {
            block_solver.solve_memory_op(op, &mut initial_witness, pedantic_solving).err()
        });

        assert!(matches!(
            err,
            Some(crate::pwg::OpcodeResolutionError::IndexOutOfBounds {
                opcode_location: _,
                index,
                array_size: 2
            }) if index == FieldElement::from(2u128)
        ));
    }

    #[test]
    fn test_index_not_fitting_in_u32() {
        let mut initial_witness = witness_map();

        let init = vec![Witness(1), Witness(2)];
        // 2^32 cannot be represented as a 32-bit memory index.
        let too_large = FieldElement::from(u128::from(u32::MAX) + 1);
        let op = MemOp::read_at_mem_index(too_large.into(), Witness(4));
        let mut block_solver = MemoryOpSolver::new(&init, &initial_witness).unwrap();

        let err = block_solver.solve_memory_op(&op, &mut initial_witness, true).err();

        assert!(matches!(
            err,
            Some(crate::pwg::OpcodeResolutionError::IndexOutOfBounds {
                opcode_location: _,
                index,
                array_size: 2
            }) if index == too_large
        ));
    }

    #[test]
    fn test_operation_larger_than_one() {
        let init = vec![Witness(1), Witness(2)];
        // Operation '2' is neither a read (0) nor a write (1).
        let op = MemOp {
            operation: FieldElement::from(2u128).into(),
            index: FieldElement::one().into(),
            value: Witness(3).into(),
        };

        // Pedantic solving rejects the malformed operation.
        let mut initial_witness = witness_map();
        let mut block_solver = MemoryOpSolver::new(&init, &initial_witness).unwrap();
        let err = block_solver.solve_memory_op(&op, &mut initial_witness, true).err();
        assert!(matches!(
            err,
            Some(crate::pwg::OpcodeResolutionError::MemoryOperationLargerThanOne {
                opcode_location: _,
                operation,
            }) if operation == FieldElement::from(2u128)
        ));

        // Without pedantic solving, any non-zero operation is treated as a write.
        let mut initial_witness = witness_map();
        let mut block_solver = MemoryOpSolver::new(&init, &initial_witness).unwrap();
        block_solver.solve_memory_op(&op, &mut initial_witness, false).unwrap();
        assert_eq!(block_solver.block_value[1], FieldElement::from(2u128));
    }
}