acir/native_types/witness_stack.rs

1use std::io::Read;
2
3use acir_field::AcirField;
4use flate2::Compression;
5use flate2::bufread::GzDecoder;
6use flate2::bufread::GzEncoder;
7use msgpack_tagged::MsgpackTagged;
8use serde::{Deserialize, Serialize};
9use thiserror::Error;
10
11use crate::SerializationFormat;
12use crate::serialization;
13
14use super::WitnessMap;
15
16#[derive(Debug, Error)]
17enum SerializationError {
18    #[error("error compressing witness stack: {0}")]
19    Compress(std::io::Error),
20
21    #[error("error decompressing witness stack: {0}")]
22    Decompress(std::io::Error),
23
24    #[error("error serializing witness stack: {0}")]
25    Serialize(std::io::Error),
26
27    #[error("error deserializing witness stack: {0}")]
28    Deserialize(std::io::Error),
29}
30
31/// Native error for serializing/deserializing a witness stack.
32#[derive(Debug, Error)]
33#[error(transparent)]
34pub struct WitnessStackError(#[from] SerializationError);
35
36/// An ordered set of witness maps for separate circuits
37#[derive(Clone, Debug, PartialEq, Eq, Hash, PartialOrd, Ord, Default)]
38#[derive(Serialize, Deserialize, MsgpackTagged)]
39#[cfg_attr(feature = "arb", derive(proptest_derive::Arbitrary))]
40pub struct WitnessStack<F> {
41    #[tag(0)]
42    stack: Vec<StackItem<F>>,
43}
44
45#[derive(Clone, Debug, PartialEq, Eq, Hash, PartialOrd, Ord, Default)]
46#[derive(Serialize, Deserialize, MsgpackTagged)]
47#[cfg_attr(feature = "arb", derive(proptest_derive::Arbitrary))]
48pub struct StackItem<F> {
49    /// Index into a [crate::circuit::Program] function list for which we have an associated witness
50    #[tag(0)]
51    pub index: u32,
52    /// A full witness for the respective constraint system specified by the index
53    #[tag(1)]
54    pub witness: WitnessMap<F>,
55}
56
57impl<F> WitnessStack<F> {
58    /// Append an element to the top of the stack
59    pub fn push(&mut self, index: u32, witness: WitnessMap<F>) {
60        self.stack.push(StackItem { index, witness });
61    }
62
63    /// Removes the top element from the stack and return its
64    pub fn pop(&mut self) -> Option<StackItem<F>> {
65        self.stack.pop()
66    }
67
68    /// Returns the top element of the stack, or `None` if it is empty
69    pub fn peek(&self) -> Option<&StackItem<F>> {
70        self.stack.last()
71    }
72
73    /// Returns the size of the stack
74    pub fn length(&self) -> usize {
75        self.stack.len()
76    }
77}
78
79impl<F: AcirField + Serialize + MsgpackTagged> WitnessStack<F> {
80    /// Serialize and compress.
81    pub fn serialize(&self) -> Result<Vec<u8>, WitnessStackError> {
82        let format = SerializationFormat::from_env()
83            .map_err(|err| SerializationError::Serialize(std::io::Error::other(err)))?;
84        self.serialize_with_format(format.unwrap_or_default())
85    }
86
87    /// Serialize and compress with a given format.
88    pub fn serialize_with_format(
89        &self,
90        format: SerializationFormat,
91    ) -> Result<Vec<u8>, WitnessStackError> {
92        let buf = serialization::serialize_with_format(self, format)
93            .map_err(|e| WitnessStackError(SerializationError::Serialize(e)))?;
94
95        let mut deflater = GzEncoder::new(buf.as_slice(), Compression::best());
96        let mut buf = Vec::new();
97        deflater
98            .read_to_end(&mut buf)
99            .map_err(|e| WitnessStackError(SerializationError::Compress(e)))?;
100
101        Ok(buf)
102    }
103}
104
105impl<F: AcirField + for<'a> Deserialize<'a> + MsgpackTagged> WitnessStack<F> {
106    /// Decompress and deserialize.
107    pub fn deserialize(buf: &[u8]) -> Result<Self, WitnessStackError> {
108        let mut deflater = GzDecoder::new(buf);
109        let mut buf = Vec::new();
110        deflater
111            .read_to_end(&mut buf)
112            .map_err(|e| WitnessStackError(SerializationError::Decompress(e)))?;
113
114        serialization::deserialize_any_format(&buf)
115            .map_err(|e| WitnessStackError(SerializationError::Deserialize(e)))
116    }
117}
118
119impl<F> From<WitnessMap<F>> for WitnessStack<F> {
120    fn from(witness: WitnessMap<F>) -> Self {
121        let stack = vec![StackItem { index: 0, witness }];
122        Self { stack }
123    }
124}
125
#[cfg(test)]
mod tests {
    use super::*;
    use crate::native_types::Witness;
    use acir_field::FieldElement;

    /// Serializes `stack` with `WitnessStack::serialize` and deserializes it back.
    fn round_trip(stack: &WitnessStack<FieldElement>) -> WitnessStack<FieldElement> {
        let serialized = stack.serialize().expect("Serialization should succeed");
        WitnessStack::deserialize(&serialized).expect("Deserialization should succeed")
    }

    #[test]
    fn test_round_trip_serialization() {
        // Create a witness stack with multiple stack items
        let mut stack = WitnessStack::default();

        // First function call with some witnesses
        let mut witness1 = WitnessMap::new();
        witness1.insert(Witness(0), FieldElement::from(42u128));
        witness1.insert(Witness(1), FieldElement::from(123u128));
        stack.push(0, witness1);

        // Second function call with different witnesses
        let mut witness2 = WitnessMap::new();
        witness2.insert(Witness(0), FieldElement::from(999u128));
        witness2.insert(Witness(5), FieldElement::zero());
        stack.push(1, witness2);

        // Third function call
        let mut witness3 = WitnessMap::new();
        witness3.insert(Witness(10), FieldElement::one());
        witness3.insert(Witness(20), FieldElement::from(u128::MAX));
        stack.push(2, witness3);

        assert_eq!(stack, round_trip(&stack));
    }

    #[test]
    fn test_round_trip_empty_witness_stack() {
        // Test with an empty witness stack
        let original = WitnessStack::<FieldElement>::default();
        assert_eq!(original, round_trip(&original));
    }

    #[test]
    fn test_round_trip_single_stack_item() {
        // Test with a single stack item
        let mut stack = WitnessStack::default();
        let mut witness = WitnessMap::new();
        witness.insert(Witness(0), FieldElement::from(12345u128));
        witness.insert(Witness(1), FieldElement::from(67890u128));
        stack.push(0, witness);

        assert_eq!(stack, round_trip(&stack));
    }

    #[test]
    fn test_round_trip_from_witness_map() {
        // Test conversion from WitnessMap and serialization
        let mut witness = WitnessMap::new();
        witness.insert(Witness(0), FieldElement::from(111u128));
        witness.insert(Witness(1), FieldElement::from(222u128));
        witness.insert(Witness(2), FieldElement::from(333u128));

        let original = WitnessStack::from(witness);
        assert_eq!(original, round_trip(&original));
    }

    #[test]
    fn test_round_trip_large_stack() {
        // Test with many stack items
        let mut stack = WitnessStack::default();

        for i in 0..10 {
            let mut witness = WitnessMap::new();
            witness.insert(Witness(i), FieldElement::from(u128::from(i) * 100));
            witness.insert(Witness(i + 100), FieldElement::from(u128::from(i) * 1000));
            stack.push(i, witness);
        }

        assert_eq!(stack, round_trip(&stack));
    }

    #[test]
    fn test_round_trip_with_explicit_format() {
        // Output of `serialize_with_format` must be readable by `deserialize`, which does not
        // take a format.
        let mut stack = WitnessStack::default();
        let mut witness = WitnessMap::new();
        witness.insert(Witness(3), FieldElement::from(7u128));
        stack.push(4, witness);

        let serialized = stack
            .serialize_with_format(crate::SerializationFormat::default())
            .expect("Serialization should succeed");
        let deserialized =
            WitnessStack::deserialize(&serialized).expect("Deserialization should succeed");

        assert_eq!(stack, deserialized);
    }

    #[test]
    fn test_deserialize_rejects_invalid_input() {
        // Bytes that are not a gzip stream must surface as an error rather than a panic.
        let result = WitnessStack::<FieldElement>::deserialize(b"definitely not gzip");
        assert!(result.is_err());
    }

    #[test]
    fn test_stack_operations() {
        // Test stack operations work correctly, down to an empty stack
        let mut stack = WitnessStack::default();

        let mut witness1 = WitnessMap::new();
        witness1.insert(Witness(0), FieldElement::from(1u128));
        stack.push(0, witness1.clone());

        let mut witness2 = WitnessMap::new();
        witness2.insert(Witness(1), FieldElement::from(2u128));
        stack.push(1, witness2.clone());

        assert_eq!(stack.length(), 2);
        assert_eq!(stack.peek().unwrap().index, 1);

        let popped = stack.pop().unwrap();
        assert_eq!(popped.index, 1);
        assert_eq!(popped.witness, witness2);

        assert_eq!(stack.length(), 1);
        assert_eq!(stack.peek().unwrap().index, 0);

        let popped = stack.pop().unwrap();
        assert_eq!(popped.index, 0);
        assert_eq!(popped.witness, witness1);

        // Once drained, the stack reports zero length and yields `None`.
        assert_eq!(stack.length(), 0);
        assert!(stack.peek().is_none());
        assert!(stack.pop().is_none());
    }
}