acvm/pwg/memory_op.rs
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 147 148 149 150 151 152 153 154 155 156 157 158 159 160 161 162 163 164 165 166 167 168 169 170 171 172 173 174 175 176 177 178 179 180 181 182 183 184 185 186 187 188 189 190 191 192 193 194 195 196 197 198 199 200 201 202 203 204 205 206 207 208 209 210 211 212 213 214 215 216 217 218 219 220 221 222 223 224
use acir::{
AcirField,
circuit::opcodes::MemOp,
native_types::{Witness, WitnessMap},
};
use super::{ErrorLocation, OpcodeResolutionError};
use super::{arithmetic::ExpressionSolver, get_value, insert_value, witness_to_value};
/// Index of a cell in a memory block; only values that fit in 32 bits are valid indices.
type MemoryIndex = u32;
/// Maintains the state for solving [`MemoryInit`][`acir::circuit::Opcode::MemoryInit`] and [`MemoryOp`][`acir::circuit::Opcode::MemoryOp`] opcodes.
///
/// One solver tracks the contents of a single memory block.
pub(crate) struct MemoryOpSolver<F> {
    /// Known values of the memory block, indexed by memory index.
    ///
    /// The vector is sized from the `init` witnesses when the solver is created.
    /// The methods in this file only overwrite existing cells and never resize it.
    pub(super) block_value: Vec<F>,
}
impl<F: AcirField> MemoryOpSolver<F> {
    /// Creates a new `MemoryOpSolver` whose memory block holds the values of `init`.
    ///
    /// Each witness in `init` is resolved against `initial_witness`. The resulting
    /// values become the block contents, in the same order.
    ///
    /// # Errors
    /// Propagates the first error returned by `witness_to_value` (presumably a witness
    /// in `init` that has no value in `initial_witness`).
    pub(crate) fn new(
        init: &[Witness],
        initial_witness: &WitnessMap<F>,
    ) -> Result<Self, OpcodeResolutionError<F>> {
        let block_value = init
            .iter()
            .map(|witness| witness_to_value(initial_witness, *witness).copied())
            .collect::<Result<Vec<_>, _>>()?;
        Ok(Self { block_value })
    }

    /// Number of cells in the memory block.
    ///
    /// # Panics
    /// Panics if the block holds more than `u32::MAX` cells.
    fn len(&self) -> u32 {
        u32::try_from(self.block_value.len()).expect("expected a length that fits into a u32")
    }

    /// Builds the `IndexOutOfBounds` error reported for an access at `index`.
    ///
    /// The opcode location is left as `ErrorLocation::Unresolved`. `array_size` is the
    /// current block length.
    fn index_out_of_bounds(&self, index: F) -> OpcodeResolutionError<F> {
        OpcodeResolutionError::IndexOutOfBounds {
            opcode_location: ErrorLocation::Unresolved,
            index,
            array_size: self.len(),
        }
    }

    /// Converts a field element into a memory index.
    ///
    /// Only values that fit in 32 bits are valid memory indices. This does NOT check
    /// the index against the block length; the read/write helpers do that.
    ///
    /// # Errors
    /// Returns `IndexOutOfBounds` if `index` does not fit in a `u32`.
    fn index_from_field(&self, index: F) -> Result<MemoryIndex, OpcodeResolutionError<F>> {
        index.try_to_u32().ok_or_else(|| self.index_out_of_bounds(index))
    }

    /// Overwrites the block cell at `index` with `value`.
    ///
    /// # Errors
    /// Returns `IndexOutOfBounds` if `index` is outside the block.
    fn write_memory_index(
        &mut self,
        index: MemoryIndex,
        value: F,
    ) -> Result<(), OpcodeResolutionError<F>> {
        if index >= self.len() {
            return Err(self.index_out_of_bounds(F::from(u128::from(index))));
        }
        self.block_value[index as usize] = value;
        Ok(())
    }

    /// Returns the value stored in the block cell at `index`.
    ///
    /// # Errors
    /// Returns `IndexOutOfBounds` if `index` is outside the block.
    fn read_memory_index(&self, index: MemoryIndex) -> Result<F, OpcodeResolutionError<F>> {
        // `ok_or_else` builds the error (and its field conversion) only when the read fails.
        self.block_value
            .get(index as usize)
            .copied()
            .ok_or_else(|| self.index_out_of_bounds(F::from(u128::from(index))))
    }

    /// Processes a single memory operation against the block.
    ///
    /// `op` carries three [acir::native_types::Expression]s: the operation type, the
    /// index and the value.
    /// - The operation must resolve to a known value (it is expected to be a constant
    ///   expression). `0` means read and `1` means write. When `pedantic_solving` is set,
    ///   any other value is rejected. Otherwise every non-zero value is treated as a write.
    /// - The index is not required to be constant, but it must resolve to a known value
    ///   through (partial) evaluation against `initial_witness`.
    ///
    /// READ: copies the block cell at `op.index` into the witness that `op.value` reduces
    /// to, and records that value in `initial_witness`.
    ///
    /// WRITE: stores the value that `op.value` resolves to in the block cell at `op.index`.
    ///
    /// # Errors
    /// - Propagates errors from `get_value` while resolving the operation, the index or
    ///   (on a write) the value.
    /// - `IndexOutOfBounds` if the index does not fit in a `u32` or lies outside the block.
    /// - `MemoryOperationLargerThanOne` if `pedantic_solving` is set and the operation is
    ///   neither `0` nor `1`.
    /// - Propagates errors from `insert_value` on a read.
    ///
    /// # Panics
    /// Panics on a read if `op.value` does not reduce to a single witness.
    pub(crate) fn solve_memory_op(
        &mut self,
        op: &MemOp<F>,
        initial_witness: &mut WitnessMap<F>,
        pedantic_solving: bool,
    ) -> Result<(), OpcodeResolutionError<F>> {
        let operation = get_value(&op.operation, initial_witness)?;

        // Find the memory index associated with this memory operation.
        let index = get_value(&op.index, initial_witness)?;
        let memory_index = self.index_from_field(index)?;

        // Calculate the value associated with this memory operation.
        //
        // In read operations, this corresponds to the witness index at which the value from memory will be written.
        // In write operations, this corresponds to the expression which will be written to memory.
        let value = ExpressionSolver::evaluate(&op.value, initial_witness);

        // `operation == 0` implies a read operation. (`operation == 1` implies write operation).
        let is_read_operation = operation.is_zero();

        // In pedantic mode the operation must be exactly 0 or 1.
        if pedantic_solving && !is_read_operation && !operation.is_one() {
            return Err(OpcodeResolutionError::MemoryOperationLargerThanOne {
                opcode_location: ErrorLocation::Unresolved,
                operation,
            });
        }

        if is_read_operation {
            // `value_read = arr[memory_index]`: copy from the memory block into the
            // witness that `op.value` names.
            let value_read_witness = value.to_witness().expect(
                "Memory must be read into a specified witness index, encountered an Expression",
            );
            let value_in_array = self.read_memory_index(memory_index)?;
            insert_value(&value_read_witness, value_in_array, initial_witness)
        } else {
            // `arr[memory_index] = value_write`: copy the resolved value into the memory block.
            let value_to_write = get_value(&value, initial_witness)?;
            self.write_memory_index(memory_index, value_to_write)
        }
    }
}
#[cfg(test)]
mod tests {
    use std::collections::BTreeMap;

    use acir::{
        AcirField, FieldElement,
        circuit::opcodes::MemOp,
        native_types::{Witness, WitnessMap},
    };

    use super::MemoryOpSolver;

    /// Witness map shared by the tests: w1 = 1, w2 = 1, w3 = 2.
    fn sample_witness_map() -> WitnessMap<FieldElement> {
        WitnessMap::from(BTreeMap::from_iter([
            (Witness(1), FieldElement::from(1u128)),
            (Witness(2), FieldElement::from(1u128)),
            (Witness(3), FieldElement::from(2u128)),
        ]))
    }

    #[test]
    fn test_solver() {
        let mut witness_map = sample_witness_map();
        let init = [Witness(1), Witness(2)];
        let pedantic_solving = true;

        // Store w3 (= 2) at index 1, then load index 1 into w4.
        let ops = [
            MemOp::write_to_mem_index(FieldElement::from(1u128).into(), Witness(3).into()),
            MemOp::read_at_mem_index(FieldElement::one().into(), Witness(4)),
        ];

        let mut solver = MemoryOpSolver::new(&init, &witness_map).unwrap();
        for op in &ops {
            solver.solve_memory_op(op, &mut witness_map, pedantic_solving).unwrap();
        }

        assert_eq!(witness_map[&Witness(4)], FieldElement::from(2u128));
    }

    #[test]
    fn test_index_out_of_bounds() {
        let mut witness_map = sample_witness_map();
        let init = [Witness(1), Witness(2)];
        let pedantic_solving = true;

        // The block has two cells: writing index 1 is fine, reading index 2 must fail.
        let ops = [
            MemOp::write_to_mem_index(FieldElement::from(1u128).into(), Witness(3).into()),
            MemOp::read_at_mem_index(FieldElement::from(2u128).into(), Witness(4)),
        ];

        let mut solver = MemoryOpSolver::new(&init, &witness_map).unwrap();
        // Stop at the first failing operation and keep its error.
        let first_error = ops
            .iter()
            .find_map(|op| solver.solve_memory_op(op, &mut witness_map, pedantic_solving).err());

        assert!(matches!(
            first_error,
            Some(crate::pwg::OpcodeResolutionError::IndexOutOfBounds {
                opcode_location: _,
                index,
                array_size: 2,
            }) if index == FieldElement::from(2u128)
        ));
    }
}