acir/native_types/
witness_stack.rs

1use std::io::Read;
2
3use acir_field::AcirField;
4use flate2::Compression;
5use flate2::bufread::GzDecoder;
6use flate2::bufread::GzEncoder;
7use serde::{Deserialize, Serialize};
8use thiserror::Error;
9
10use crate::SerializationFormat;
11use crate::serialization;
12
13use super::WitnessMap;
14
15#[derive(Debug, Error)]
16enum SerializationError {
17    #[error("error compressing witness stack: {0}")]
18    Compress(std::io::Error),
19
20    #[error("error decompressing witness stack: {0}")]
21    Decompress(std::io::Error),
22
23    #[error("error serializing witness stack: {0}")]
24    Serialize(std::io::Error),
25
26    #[error("error deserializing witness stack: {0}")]
27    Deserialize(std::io::Error),
28}
29
30/// Native error for serializing/deserializing a witness stack.
31#[derive(Debug, Error)]
32#[error(transparent)]
33pub struct WitnessStackError(#[from] SerializationError);
34
35/// An ordered set of witness maps for separate circuits
36#[derive(Clone, Debug, PartialEq, Eq, Hash, PartialOrd, Ord, Default, Serialize, Deserialize)]
37#[cfg_attr(feature = "arb", derive(proptest_derive::Arbitrary))]
38pub struct WitnessStack<F> {
39    stack: Vec<StackItem<F>>,
40}
41
42#[derive(Clone, Debug, PartialEq, Eq, Hash, PartialOrd, Ord, Default, Serialize, Deserialize)]
43#[cfg_attr(feature = "arb", derive(proptest_derive::Arbitrary))]
44pub struct StackItem<F> {
45    /// Index into a [crate::circuit::Program] function list for which we have an associated witness
46    pub index: u32,
47    /// A full witness for the respective constraint system specified by the index
48    pub witness: WitnessMap<F>,
49}
50
51impl<F> WitnessStack<F> {
52    /// Append an element to the top of the stack
53    pub fn push(&mut self, index: u32, witness: WitnessMap<F>) {
54        self.stack.push(StackItem { index, witness });
55    }
56
57    /// Removes the top element from the stack and return its
58    pub fn pop(&mut self) -> Option<StackItem<F>> {
59        self.stack.pop()
60    }
61
62    /// Returns the top element of the stack, or `None` if it is empty
63    pub fn peek(&self) -> Option<&StackItem<F>> {
64        self.stack.last()
65    }
66
67    /// Returns the size of the stack
68    pub fn length(&self) -> usize {
69        self.stack.len()
70    }
71}
72
73impl<F: AcirField + Serialize> WitnessStack<F> {
74    /// Serialize and compress.
75    pub fn serialize(&self) -> Result<Vec<u8>, WitnessStackError> {
76        let format = SerializationFormat::from_env()
77            .map_err(|err| SerializationError::Serialize(std::io::Error::other(err)))?;
78        self.serialize_with_format(format.unwrap_or_default())
79    }
80
81    /// Serialize and compress with a given format.
82    pub fn serialize_with_format(
83        &self,
84        format: SerializationFormat,
85    ) -> Result<Vec<u8>, WitnessStackError> {
86        let buf = serialization::serialize_with_format(self, format)
87            .map_err(|e| WitnessStackError(SerializationError::Serialize(e)))?;
88
89        let mut deflater = GzEncoder::new(buf.as_slice(), Compression::best());
90        let mut buf = Vec::new();
91        deflater
92            .read_to_end(&mut buf)
93            .map_err(|e| WitnessStackError(SerializationError::Compress(e)))?;
94
95        Ok(buf)
96    }
97}
98
99impl<F: AcirField + for<'a> Deserialize<'a>> WitnessStack<F> {
100    /// Decompress and deserialize.
101    pub fn deserialize(buf: &[u8]) -> Result<Self, WitnessStackError> {
102        let mut deflater = GzDecoder::new(buf);
103        let mut buf = Vec::new();
104        deflater
105            .read_to_end(&mut buf)
106            .map_err(|e| WitnessStackError(SerializationError::Decompress(e)))?;
107
108        serialization::deserialize_any_format(&buf)
109            .map_err(|e| WitnessStackError(SerializationError::Deserialize(e)))
110    }
111}
112
113impl<F> From<WitnessMap<F>> for WitnessStack<F> {
114    fn from(witness: WitnessMap<F>) -> Self {
115        let stack = vec![StackItem { index: 0, witness }];
116        Self { stack }
117    }
118}
119
#[cfg(test)]
mod tests {
    use super::*;
    use crate::native_types::Witness;
    use acir_field::FieldElement;

    /// Serializes `stack`, deserializes the bytes back, and asserts the result equals `stack`.
    fn assert_round_trip(stack: &WitnessStack<FieldElement>) {
        let serialized = stack.serialize().expect("Serialization should succeed");
        let deserialized: WitnessStack<FieldElement> =
            WitnessStack::deserialize(&serialized).expect("Deserialization should succeed");
        assert_eq!(stack, &deserialized);
    }

    #[test]
    fn test_round_trip_serialization() {
        // Create a witness stack with multiple stack items
        let mut stack = WitnessStack::default();

        // First function call with some witnesses
        let mut witness1 = WitnessMap::new();
        witness1.insert(Witness(0), FieldElement::from(42u128));
        witness1.insert(Witness(1), FieldElement::from(123u128));
        stack.push(0, witness1);

        // Second function call with different witnesses
        let mut witness2 = WitnessMap::new();
        witness2.insert(Witness(0), FieldElement::from(999u128));
        witness2.insert(Witness(5), FieldElement::zero());
        stack.push(1, witness2);

        // Third function call
        let mut witness3 = WitnessMap::new();
        witness3.insert(Witness(10), FieldElement::one());
        witness3.insert(Witness(20), FieldElement::from(u128::MAX));
        stack.push(2, witness3);

        assert_round_trip(&stack);
    }

    #[test]
    fn test_round_trip_empty_witness_stack() {
        // Test with an empty witness stack
        let original = WitnessStack::<FieldElement>::default();
        assert_round_trip(&original);
    }

    #[test]
    fn test_round_trip_single_stack_item() {
        // Test with a single stack item
        let mut stack = WitnessStack::default();
        let mut witness = WitnessMap::new();
        witness.insert(Witness(0), FieldElement::from(12345u128));
        witness.insert(Witness(1), FieldElement::from(67890u128));
        stack.push(0, witness);

        assert_round_trip(&stack);
    }

    #[test]
    fn test_round_trip_from_witness_map() {
        // Test conversion from WitnessMap and serialization
        let mut witness = WitnessMap::new();
        witness.insert(Witness(0), FieldElement::from(111u128));
        witness.insert(Witness(1), FieldElement::from(222u128));
        witness.insert(Witness(2), FieldElement::from(333u128));

        let original = WitnessStack::from(witness);
        assert_round_trip(&original);
    }

    #[test]
    fn test_round_trip_large_stack() {
        // Test with many stack items
        let mut stack = WitnessStack::default();

        for i in 0..10 {
            let mut witness = WitnessMap::new();
            witness.insert(Witness(i), FieldElement::from(u128::from(i) * 100));
            witness.insert(Witness(i + 100), FieldElement::from(u128::from(i) * 1000));
            stack.push(i, witness);
        }

        assert_round_trip(&stack);
    }

    #[test]
    fn test_deserialize_rejects_non_gzip_input() {
        // Bytes without a gzip header must surface as an error rather than a panic.
        let result = WitnessStack::<FieldElement>::deserialize(b"definitely not gzip");
        assert!(result.is_err());
    }

    #[test]
    fn test_stack_operations() {
        // Test stack operations work correctly, including on an empty stack
        let mut stack = WitnessStack::<FieldElement>::default();
        assert_eq!(stack.length(), 0);
        assert!(stack.peek().is_none());

        let mut witness1 = WitnessMap::new();
        witness1.insert(Witness(0), FieldElement::from(1u128));
        stack.push(0, witness1.clone());

        let mut witness2 = WitnessMap::new();
        witness2.insert(Witness(1), FieldElement::from(2u128));
        stack.push(1, witness2.clone());

        assert_eq!(stack.length(), 2);
        assert_eq!(stack.peek().unwrap().index, 1);

        let popped = stack.pop().unwrap();
        assert_eq!(popped.index, 1);
        assert_eq!(popped.witness, witness2);

        assert_eq!(stack.length(), 1);
        assert_eq!(stack.peek().unwrap().index, 0);

        // Pop the remaining item and confirm it is the first one pushed.
        let popped = stack.pop().unwrap();
        assert_eq!(popped.index, 0);
        assert_eq!(popped.witness, witness1);

        // The stack is now empty again.
        assert_eq!(stack.length(), 0);
        assert!(stack.peek().is_none());
        assert!(stack.pop().is_none());
    }
}