acir/native_types/
witness_stack.rs1use std::io::Read;
2
3use acir_field::AcirField;
4use flate2::Compression;
5use flate2::bufread::GzDecoder;
6use flate2::bufread::GzEncoder;
7use serde::{Deserialize, Serialize};
8use thiserror::Error;
9
10use crate::SerializationFormat;
11use crate::serialization;
12
13use super::WitnessMap;
14
15#[derive(Debug, Error)]
16enum SerializationError {
17 #[error("error compressing witness stack: {0}")]
18 Compress(std::io::Error),
19
20 #[error("error decompressing witness stack: {0}")]
21 Decompress(std::io::Error),
22
23 #[error("error serializing witness stack: {0}")]
24 Serialize(std::io::Error),
25
26 #[error("error deserializing witness stack: {0}")]
27 Deserialize(std::io::Error),
28}
29
30#[derive(Debug, Error)]
32#[error(transparent)]
33pub struct WitnessStackError(#[from] SerializationError);
34
/// A stack of witness maps, each tagged with a `u32` index.
///
/// Items are pushed and popped in LIFO order (backed by a `Vec`).
/// NOTE(review): the index presumably identifies which circuit/function a witness
/// map belongs to — confirm against callers.
///
/// The derived `Serialize`/`Deserialize` impls are what `serialize`/`deserialize`
/// encode, so changing this struct's fields may change the encoded format.
#[derive(Clone, Debug, PartialEq, Eq, Hash, PartialOrd, Ord, Default, Serialize, Deserialize)]
#[cfg_attr(feature = "arb", derive(proptest_derive::Arbitrary))]
pub struct WitnessStack<F> {
    // Bottom of the stack is index 0; the top is the last element.
    stack: Vec<StackItem<F>>,
}
41
/// One entry of a [`WitnessStack`]: a witness map together with its index.
///
/// Field order matters: it drives the derived `Ord`/`PartialOrd` comparison
/// (index first) and the serde encoding.
#[derive(Clone, Debug, PartialEq, Eq, Hash, PartialOrd, Ord, Default, Serialize, Deserialize)]
#[cfg_attr(feature = "arb", derive(proptest_derive::Arbitrary))]
pub struct StackItem<F> {
    /// Index the witness map was pushed with (see [`WitnessStack::push`]).
    pub index: u32,
    /// The witness assignments for this entry.
    pub witness: WitnessMap<F>,
}
50
51impl<F> WitnessStack<F> {
52 pub fn push(&mut self, index: u32, witness: WitnessMap<F>) {
54 self.stack.push(StackItem { index, witness });
55 }
56
57 pub fn pop(&mut self) -> Option<StackItem<F>> {
59 self.stack.pop()
60 }
61
62 pub fn peek(&self) -> Option<&StackItem<F>> {
64 self.stack.last()
65 }
66
67 pub fn length(&self) -> usize {
69 self.stack.len()
70 }
71}
72
73impl<F: AcirField + Serialize> WitnessStack<F> {
74 pub fn serialize(&self) -> Result<Vec<u8>, WitnessStackError> {
76 let format = SerializationFormat::from_env()
77 .map_err(|err| SerializationError::Serialize(std::io::Error::other(err)))?;
78 self.serialize_with_format(format.unwrap_or_default())
79 }
80
81 pub fn serialize_with_format(
83 &self,
84 format: SerializationFormat,
85 ) -> Result<Vec<u8>, WitnessStackError> {
86 let buf = serialization::serialize_with_format(self, format)
87 .map_err(|e| WitnessStackError(SerializationError::Serialize(e)))?;
88
89 let mut deflater = GzEncoder::new(buf.as_slice(), Compression::best());
90 let mut buf = Vec::new();
91 deflater
92 .read_to_end(&mut buf)
93 .map_err(|e| WitnessStackError(SerializationError::Compress(e)))?;
94
95 Ok(buf)
96 }
97}
98
99impl<F: AcirField + for<'a> Deserialize<'a>> WitnessStack<F> {
100 pub fn deserialize(buf: &[u8]) -> Result<Self, WitnessStackError> {
102 let mut deflater = GzDecoder::new(buf);
103 let mut buf = Vec::new();
104 deflater
105 .read_to_end(&mut buf)
106 .map_err(|e| WitnessStackError(SerializationError::Decompress(e)))?;
107
108 serialization::deserialize_any_format(&buf)
109 .map_err(|e| WitnessStackError(SerializationError::Deserialize(e)))
110 }
111}
112
113impl<F> From<WitnessMap<F>> for WitnessStack<F> {
114 fn from(witness: WitnessMap<F>) -> Self {
115 let stack = vec![StackItem { index: 0, witness }];
116 Self { stack }
117 }
118}
119
#[cfg(test)]
mod tests {
    use super::*;
    use crate::native_types::Witness;
    use acir_field::FieldElement;

    /// A multi-item stack survives a serialize/deserialize round trip unchanged.
    #[test]
    fn test_round_trip_serialization() {
        let mut stack = WitnessStack::default();

        let mut witness1 = WitnessMap::new();
        witness1.insert(Witness(0), FieldElement::from(42u128));
        witness1.insert(Witness(1), FieldElement::from(123u128));
        stack.push(0, witness1);

        let mut witness2 = WitnessMap::new();
        witness2.insert(Witness(0), FieldElement::from(999u128));
        witness2.insert(Witness(5), FieldElement::zero());
        stack.push(1, witness2);

        let mut witness3 = WitnessMap::new();
        witness3.insert(Witness(10), FieldElement::one());
        witness3.insert(Witness(20), FieldElement::from(u128::MAX));
        stack.push(2, witness3);

        let serialized = stack.serialize().expect("Serialization should succeed");

        let deserialized =
            WitnessStack::deserialize(&serialized).expect("Deserialization should succeed");

        assert_eq!(stack, deserialized);
    }

    /// An empty stack round-trips to an empty stack.
    #[test]
    fn test_round_trip_empty_witness_stack() {
        let original = WitnessStack::<FieldElement>::default();

        let serialized = original.serialize().expect("Serialization should succeed");
        let deserialized =
            WitnessStack::deserialize(&serialized).expect("Deserialization should succeed");

        assert_eq!(original, deserialized);
    }

    /// A stack with exactly one item round-trips unchanged.
    #[test]
    fn test_round_trip_single_stack_item() {
        let mut stack = WitnessStack::default();
        let mut witness = WitnessMap::new();
        witness.insert(Witness(0), FieldElement::from(12345u128));
        witness.insert(Witness(1), FieldElement::from(67890u128));
        stack.push(0, witness);

        let serialized = stack.serialize().expect("Serialization should succeed");
        let deserialized =
            WitnessStack::deserialize(&serialized).expect("Deserialization should succeed");

        assert_eq!(stack, deserialized);
    }

    /// A stack built via `From<WitnessMap>` round-trips unchanged.
    #[test]
    fn test_round_trip_from_witness_map() {
        let mut witness = WitnessMap::new();
        witness.insert(Witness(0), FieldElement::from(111u128));
        witness.insert(Witness(1), FieldElement::from(222u128));
        witness.insert(Witness(2), FieldElement::from(333u128));

        let original = WitnessStack::from(witness);

        let serialized = original.serialize().expect("Serialization should succeed");
        let deserialized =
            WitnessStack::deserialize(&serialized).expect("Deserialization should succeed");

        assert_eq!(original, deserialized);
    }

    /// A ten-item stack round-trips unchanged.
    #[test]
    fn test_round_trip_large_stack() {
        let mut stack = WitnessStack::default();

        for i in 0..10 {
            let mut witness = WitnessMap::new();
            witness.insert(Witness(i), FieldElement::from(u128::from(i) * 100));
            witness.insert(Witness(i + 100), FieldElement::from(u128::from(i) * 1000));
            stack.push(i, witness);
        }

        let serialized = stack.serialize().expect("Serialization should succeed");
        let deserialized =
            WitnessStack::deserialize(&serialized).expect("Deserialization should succeed");

        assert_eq!(stack, deserialized);
    }

    /// Bytes that are not a gzip stream are rejected rather than silently decoded.
    #[test]
    fn test_deserialize_rejects_non_gzip_input() {
        let result = WitnessStack::<FieldElement>::deserialize(b"not a gzip stream");
        assert!(result.is_err());
    }

    /// push/peek/pop follow LIFO order down to an empty stack.
    #[test]
    fn test_stack_operations() {
        let mut stack = WitnessStack::default();

        let mut witness1 = WitnessMap::new();
        witness1.insert(Witness(0), FieldElement::from(1u128));
        stack.push(0, witness1.clone());

        let mut witness2 = WitnessMap::new();
        witness2.insert(Witness(1), FieldElement::from(2u128));
        stack.push(1, witness2.clone());

        assert_eq!(stack.length(), 2);
        assert_eq!(stack.peek().unwrap().index, 1);

        let popped = stack.pop().unwrap();
        assert_eq!(popped.index, 1);
        assert_eq!(popped.witness, witness2);

        assert_eq!(stack.length(), 1);
        let remaining = stack.peek().unwrap();
        assert_eq!(remaining.index, 0);
        assert_eq!(remaining.witness, witness1);

        let last = stack.pop().unwrap();
        assert_eq!(last.index, 0);
        assert_eq!(last.witness, witness1);

        // Once drained, the stack reports empty and yields nothing.
        assert_eq!(stack.length(), 0);
        assert!(stack.peek().is_none());
        assert!(stack.pop().is_none());
    }
}