1use acir::{
2 AcirField,
3 circuit::opcodes::MemOp,
4 native_types::{Witness, WitnessMap},
5};
6
7use super::{ErrorLocation, OpcodeResolutionError};
8use super::{arithmetic::ExpressionSolver, get_value, insert_value, witness_to_value};
9
10type MemoryIndex = u32;
11
12pub(crate) struct MemoryOpSolver<F> {
14 pub(super) block_value: Vec<F>,
18}
19
20impl<F: AcirField> MemoryOpSolver<F> {
21 pub(crate) fn new(
23 init: &[Witness],
24 initial_witness: &WitnessMap<F>,
25 ) -> Result<Self, OpcodeResolutionError<F>> {
26 Ok(Self {
27 block_value: init
28 .iter()
29 .map(|witness| witness_to_value(initial_witness, *witness).copied())
30 .collect::<Result<Vec<_>, _>>()?,
31 })
32 }
33
34 pub(crate) fn len(&self) -> u32 {
35 u32::try_from(self.block_value.len()).expect("expected a length that fits into a u32")
36 }
37
38 pub(crate) fn index_from_field(
41 &self,
42 index: F,
43 ) -> Result<MemoryIndex, OpcodeResolutionError<F>> {
44 index.try_to_u32().ok_or_else({
45 || OpcodeResolutionError::IndexOutOfBounds {
46 opcode_location: ErrorLocation::Unresolved,
47 index,
48 array_size: self.len(),
49 }
50 })
51 }
52
53 pub(crate) fn write_memory_index(
56 &mut self,
57 index: MemoryIndex,
58 value: F,
59 ) -> Result<(), OpcodeResolutionError<F>> {
60 if index >= self.len() {
61 return Err(OpcodeResolutionError::IndexOutOfBounds {
62 opcode_location: ErrorLocation::Unresolved,
63 index: F::from(u128::from(index)),
64 array_size: self.len(),
65 });
66 }
67
68 self.block_value[index as usize] = value;
69 Ok(())
70 }
71
72 pub(crate) fn read_memory_index(
75 &self,
76 index: MemoryIndex,
77 ) -> Result<F, OpcodeResolutionError<F>> {
78 self.block_value.get(index as usize).copied().ok_or(
79 OpcodeResolutionError::IndexOutOfBounds {
80 opcode_location: ErrorLocation::Unresolved,
81 index: F::from(u128::from(index)),
82 array_size: self.len(),
83 },
84 )
85 }
86
87 pub(crate) fn solve_memory_op(
106 &mut self,
107 op: &MemOp<F>,
108 initial_witness: &mut WitnessMap<F>,
109 ) -> Result<(), OpcodeResolutionError<F>> {
110 let operation = get_value(&op.operation, initial_witness)?;
111
112 let index = get_value(&op.index, initial_witness)?;
114 let memory_index = self.index_from_field(index)?;
115
116 let value = ExpressionSolver::evaluate(&op.value, initial_witness);
121
122 let is_read_operation = operation.is_zero();
124 if !is_read_operation && !operation.is_one() {
126 let opcode_location = ErrorLocation::Unresolved;
127 return Err(OpcodeResolutionError::MemoryOperationLargerThanOne {
128 opcode_location,
129 operation,
130 });
131 }
132
133 if is_read_operation {
134 let value_read_witness = value.to_witness().expect(
139 "Memory must be read into a specified witness index, encountered an Expression",
140 );
141
142 let value_in_array = self.read_memory_index(memory_index)?;
143 insert_value(&value_read_witness, value_in_array, initial_witness)
144 } else {
145 let value_write = value;
150
151 let value_to_write = get_value(&value_write, initial_witness)?;
152 self.write_memory_index(memory_index, value_to_write)
153 }
154 }
155}
156
#[cfg(test)]
mod tests {
    use std::collections::BTreeMap;

    use acir::{
        AcirField, FieldElement,
        circuit::opcodes::MemOp,
        native_types::{Witness, WitnessMap},
    };

    use super::MemoryOpSolver;
    use crate::pwg::OpcodeResolutionError;

    /// Witness assignment shared by the tests: w1 = 1, w2 = 1, w3 = 2.
    fn starting_witnesses() -> WitnessMap<FieldElement> {
        WitnessMap::from(BTreeMap::from_iter([
            (Witness(1), FieldElement::from(1u128)),
            (Witness(2), FieldElement::from(1u128)),
            (Witness(3), FieldElement::from(2u128)),
        ]))
    }

    #[test]
    fn test_solver() {
        let mut witnesses = starting_witnesses();
        let memory_init = [Witness(1), Witness(2)];
        // Write w3 (= 2) into slot 1, then read slot 1 back into w4.
        let ops = [
            MemOp::write_to_mem_index(FieldElement::from(1u128).into(), Witness(3).into()),
            MemOp::read_at_mem_index(FieldElement::one().into(), Witness(4)),
        ];

        let mut solver = MemoryOpSolver::new(&memory_init, &witnesses).unwrap();
        ops.iter().for_each(|op| solver.solve_memory_op(op, &mut witnesses).unwrap());

        assert_eq!(witnesses[&Witness(4)], FieldElement::from(2u128));
    }

    #[test]
    fn test_index_out_of_bounds() {
        let mut witnesses = starting_witnesses();
        let memory_init = [Witness(1), Witness(2)];
        // The write is valid; the read at index 2 is past the end of a 2-slot block.
        let ops = [
            MemOp::write_to_mem_index(FieldElement::from(1u128).into(), Witness(3).into()),
            MemOp::read_at_mem_index(FieldElement::from(2u128).into(), Witness(4)),
        ];

        let mut solver = MemoryOpSolver::new(&memory_init, &witnesses).unwrap();
        // Stop at the first failing operation, as the solver would.
        let first_error =
            ops.iter().find_map(|op| solver.solve_memory_op(op, &mut witnesses).err());

        assert!(matches!(
            first_error,
            Some(OpcodeResolutionError::IndexOutOfBounds {
                opcode_location: _,
                index,
                array_size: 2
            }) if index == FieldElement::from(2u128)
        ));
    }
}