acir/native_types/
witness_stack.rs1use std::io::Read;
2
3use acir_field::AcirField;
4use flate2::Compression;
5use flate2::bufread::GzDecoder;
6use flate2::bufread::GzEncoder;
7use msgpack_tagged::MsgpackTagged;
8use serde::{Deserialize, Serialize};
9use thiserror::Error;
10
11use crate::SerializationFormat;
12use crate::serialization;
13
14use super::WitnessMap;
15
16#[derive(Debug, Error)]
17enum SerializationError {
18 #[error("error compressing witness stack: {0}")]
19 Compress(std::io::Error),
20
21 #[error("error decompressing witness stack: {0}")]
22 Decompress(std::io::Error),
23
24 #[error("error serializing witness stack: {0}")]
25 Serialize(std::io::Error),
26
27 #[error("error deserializing witness stack: {0}")]
28 Deserialize(std::io::Error),
29}
30
31#[derive(Debug, Error)]
33#[error(transparent)]
34pub struct WitnessStackError(#[from] SerializationError);
35
36#[derive(Clone, Debug, PartialEq, Eq, Hash, PartialOrd, Ord, Default)]
38#[derive(Serialize, Deserialize, MsgpackTagged)]
39#[cfg_attr(feature = "arb", derive(proptest_derive::Arbitrary))]
40pub struct WitnessStack<F> {
41 #[tag(0)]
42 stack: Vec<StackItem<F>>,
43}
44
45#[derive(Clone, Debug, PartialEq, Eq, Hash, PartialOrd, Ord, Default)]
46#[derive(Serialize, Deserialize, MsgpackTagged)]
47#[cfg_attr(feature = "arb", derive(proptest_derive::Arbitrary))]
48pub struct StackItem<F> {
49 #[tag(0)]
51 pub index: u32,
52 #[tag(1)]
54 pub witness: WitnessMap<F>,
55}
56
57impl<F> WitnessStack<F> {
58 pub fn push(&mut self, index: u32, witness: WitnessMap<F>) {
60 self.stack.push(StackItem { index, witness });
61 }
62
63 pub fn pop(&mut self) -> Option<StackItem<F>> {
65 self.stack.pop()
66 }
67
68 pub fn peek(&self) -> Option<&StackItem<F>> {
70 self.stack.last()
71 }
72
73 pub fn length(&self) -> usize {
75 self.stack.len()
76 }
77}
78
79impl<F: AcirField + Serialize + MsgpackTagged> WitnessStack<F> {
80 pub fn serialize(&self) -> Result<Vec<u8>, WitnessStackError> {
82 let format = SerializationFormat::from_env()
83 .map_err(|err| SerializationError::Serialize(std::io::Error::other(err)))?;
84 self.serialize_with_format(format.unwrap_or_default())
85 }
86
87 pub fn serialize_with_format(
89 &self,
90 format: SerializationFormat,
91 ) -> Result<Vec<u8>, WitnessStackError> {
92 let buf = serialization::serialize_with_format(self, format)
93 .map_err(|e| WitnessStackError(SerializationError::Serialize(e)))?;
94
95 let mut deflater = GzEncoder::new(buf.as_slice(), Compression::best());
96 let mut buf = Vec::new();
97 deflater
98 .read_to_end(&mut buf)
99 .map_err(|e| WitnessStackError(SerializationError::Compress(e)))?;
100
101 Ok(buf)
102 }
103}
104
105impl<F: AcirField + for<'a> Deserialize<'a> + MsgpackTagged> WitnessStack<F> {
106 pub fn deserialize(buf: &[u8]) -> Result<Self, WitnessStackError> {
108 let mut deflater = GzDecoder::new(buf);
109 let mut buf = Vec::new();
110 deflater
111 .read_to_end(&mut buf)
112 .map_err(|e| WitnessStackError(SerializationError::Decompress(e)))?;
113
114 serialization::deserialize_any_format(&buf)
115 .map_err(|e| WitnessStackError(SerializationError::Deserialize(e)))
116 }
117}
118
119impl<F> From<WitnessMap<F>> for WitnessStack<F> {
120 fn from(witness: WitnessMap<F>) -> Self {
121 let stack = vec![StackItem { index: 0, witness }];
122 Self { stack }
123 }
124}
125
#[cfg(test)]
mod tests {
    use super::*;
    use crate::native_types::Witness;
    use acir_field::FieldElement;

    #[test]
    fn test_round_trip_serialization() {
        let mut stack = WitnessStack::default();

        let mut witness1 = WitnessMap::new();
        witness1.insert(Witness(0), FieldElement::from(42u128));
        witness1.insert(Witness(1), FieldElement::from(123u128));
        stack.push(0, witness1);

        let mut witness2 = WitnessMap::new();
        witness2.insert(Witness(0), FieldElement::from(999u128));
        witness2.insert(Witness(5), FieldElement::zero());
        stack.push(1, witness2);

        let mut witness3 = WitnessMap::new();
        witness3.insert(Witness(10), FieldElement::one());
        witness3.insert(Witness(20), FieldElement::from(u128::MAX));
        stack.push(2, witness3);

        let serialized = stack.serialize().expect("Serialization should succeed");

        let deserialized =
            WitnessStack::deserialize(&serialized).expect("Deserialization should succeed");

        assert_eq!(stack, deserialized);
    }

    #[test]
    fn test_round_trip_empty_witness_stack() {
        let original = WitnessStack::<FieldElement>::default();

        let serialized = original.serialize().expect("Serialization should succeed");
        let deserialized =
            WitnessStack::deserialize(&serialized).expect("Deserialization should succeed");

        assert_eq!(original, deserialized);
    }

    #[test]
    fn test_round_trip_single_stack_item() {
        let mut stack = WitnessStack::default();
        let mut witness = WitnessMap::new();
        witness.insert(Witness(0), FieldElement::from(12345u128));
        witness.insert(Witness(1), FieldElement::from(67890u128));
        stack.push(0, witness);

        let serialized = stack.serialize().expect("Serialization should succeed");
        let deserialized =
            WitnessStack::deserialize(&serialized).expect("Deserialization should succeed");

        assert_eq!(stack, deserialized);
    }

    #[test]
    fn test_round_trip_from_witness_map() {
        let mut witness = WitnessMap::new();
        witness.insert(Witness(0), FieldElement::from(111u128));
        witness.insert(Witness(1), FieldElement::from(222u128));
        witness.insert(Witness(2), FieldElement::from(333u128));

        let original = WitnessStack::from(witness);

        // `From<WitnessMap>` must produce exactly one item with index 0.
        assert_eq!(original.length(), 1);
        assert_eq!(original.peek().map(|item| item.index), Some(0));

        let serialized = original.serialize().expect("Serialization should succeed");
        let deserialized =
            WitnessStack::deserialize(&serialized).expect("Deserialization should succeed");

        assert_eq!(original, deserialized);
    }

    #[test]
    fn test_round_trip_large_stack() {
        let mut stack = WitnessStack::default();

        for i in 0..10 {
            let mut witness = WitnessMap::new();
            witness.insert(Witness(i), FieldElement::from(u128::from(i) * 100));
            witness.insert(Witness(i + 100), FieldElement::from(u128::from(i) * 1000));
            stack.push(i, witness);
        }

        let serialized = stack.serialize().expect("Serialization should succeed");
        let deserialized =
            WitnessStack::deserialize(&serialized).expect("Deserialization should succeed");

        assert_eq!(stack, deserialized);
    }

    #[test]
    fn test_deserialize_rejects_non_gzip_input() {
        // Bytes without a gzip header must surface as an error, not a panic.
        let result = WitnessStack::<FieldElement>::deserialize(b"definitely not a gzip stream");
        assert!(result.is_err());
    }

    #[test]
    fn test_stack_operations() {
        let mut stack = WitnessStack::default();

        let mut witness1 = WitnessMap::new();
        witness1.insert(Witness(0), FieldElement::from(1u128));
        // Cloned so the bottom item's contents can be checked after it is popped.
        stack.push(0, witness1.clone());

        let mut witness2 = WitnessMap::new();
        witness2.insert(Witness(1), FieldElement::from(2u128));
        stack.push(1, witness2.clone());

        assert_eq!(stack.length(), 2);
        assert_eq!(stack.peek().unwrap().index, 1);

        let popped = stack.pop().unwrap();
        assert_eq!(popped.index, 1);
        assert_eq!(popped.witness, witness2);

        assert_eq!(stack.length(), 1);
        assert_eq!(stack.peek().unwrap().index, 0);

        let popped = stack.pop().unwrap();
        assert_eq!(popped.index, 0);
        assert_eq!(popped.witness, witness1);

        // Once drained, the stack reports empty and yields nothing.
        assert_eq!(stack.length(), 0);
        assert!(stack.peek().is_none());
        assert!(stack.pop().is_none());
    }
}