// acir/native_types/witness_map.rs

1use std::{
2    collections::{BTreeMap, btree_map},
3    io::Read,
4    ops::Index,
5};
6
7use acir_field::AcirField;
8use flate2::Compression;
9use flate2::bufread::GzDecoder;
10use flate2::bufread::GzEncoder;
11use serde::{Deserialize, Serialize};
12use thiserror::Error;
13
14use crate::{SerializationFormat, native_types::Witness, serialization};
15
16#[derive(Debug, Error)]
17enum SerializationError {
18    #[error("error compressing witness map: {0}")]
19    Compress(std::io::Error),
20
21    #[error("error decompressing witness map: {0}")]
22    Decompress(std::io::Error),
23
24    #[error("error serializing witness map: {0}")]
25    Serialize(std::io::Error),
26
27    #[error("error deserializing witness map: {0}")]
28    Deserialize(std::io::Error),
29}
30
31#[derive(Debug, Error)]
32#[error(transparent)]
33pub struct WitnessMapError(#[from] SerializationError);
34
35/// A map from the witnesses in a constraint system to the field element values
36#[derive(Clone, Debug, PartialEq, Eq, Hash, PartialOrd, Ord, Default, Serialize, Deserialize)]
37#[cfg_attr(feature = "arb", derive(proptest_derive::Arbitrary))]
38pub struct WitnessMap<F>(BTreeMap<Witness, F>);
39
40impl<F> WitnessMap<F> {
41    pub fn new() -> Self {
42        Self(BTreeMap::new())
43    }
44    pub fn get(&self, witness: &Witness) -> Option<&F> {
45        self.0.get(witness)
46    }
47    pub fn get_index(&self, index: u32) -> Option<&F> {
48        self.0.get(&index.into())
49    }
50    pub fn contains_key(&self, key: &Witness) -> bool {
51        self.0.contains_key(key)
52    }
53    pub fn insert(&mut self, key: Witness, value: F) -> Option<F> {
54        self.0.insert(key, value)
55    }
56    pub fn entry(&mut self, key: Witness) -> btree_map::Entry<'_, Witness, F> {
57        self.0.entry(key)
58    }
59}
60
61impl<F> Index<&Witness> for WitnessMap<F> {
62    type Output = F;
63
64    fn index(&self, index: &Witness) -> &Self::Output {
65        &self.0[index]
66    }
67}
68
69pub struct IntoIter<F>(btree_map::IntoIter<Witness, F>);
70
71impl<F> Iterator for IntoIter<F> {
72    type Item = (Witness, F);
73
74    fn next(&mut self) -> Option<Self::Item> {
75        self.0.next()
76    }
77}
78
79impl<F> IntoIterator for WitnessMap<F> {
80    type Item = (Witness, F);
81    type IntoIter = IntoIter<F>;
82
83    fn into_iter(self) -> Self::IntoIter {
84        IntoIter(self.0.into_iter())
85    }
86}
87
88impl<F> From<BTreeMap<Witness, F>> for WitnessMap<F> {
89    fn from(value: BTreeMap<Witness, F>) -> Self {
90        Self(value)
91    }
92}
93
94impl<F: AcirField + Serialize> WitnessMap<F> {
95    /// Serialize and compress.
96    pub fn serialize(&self) -> Result<Vec<u8>, WitnessMapError> {
97        let format = SerializationFormat::from_env()
98            .map_err(|err| SerializationError::Serialize(std::io::Error::other(err)))?;
99        self.serialize_with_format(format.unwrap_or_default())
100    }
101
102    /// Serialize and compress with a given format.
103    pub fn serialize_with_format(
104        &self,
105        format: SerializationFormat,
106    ) -> Result<Vec<u8>, WitnessMapError> {
107        let buf = serialization::serialize_with_format(self, format)
108            .map_err(|e| WitnessMapError(SerializationError::Serialize(e)))?;
109
110        let mut deflater = GzEncoder::new(buf.as_slice(), Compression::best());
111        let mut buf = Vec::new();
112        deflater
113            .read_to_end(&mut buf)
114            .map_err(|e| WitnessMapError(SerializationError::Compress(e)))?;
115
116        Ok(buf)
117    }
118}
119
120impl<F: AcirField + for<'a> Deserialize<'a>> WitnessMap<F> {
121    /// Decompress and deserialize.
122    pub fn deserialize(buf: &[u8]) -> Result<Self, WitnessMapError> {
123        let mut deflater = GzDecoder::new(buf);
124        let mut buf = Vec::new();
125        deflater
126            .read_to_end(&mut buf)
127            .map_err(|e| WitnessMapError(SerializationError::Decompress(e)))?;
128
129        serialization::deserialize_any_format(&buf)
130            .map_err(|e| WitnessMapError(SerializationError::Deserialize(e)))
131    }
132}
133
#[cfg(test)]
mod tests {
    use super::*;
    use acir_field::FieldElement;

    /// Serializes `original` with the environment-selected format, deserializes
    /// the bytes again, and asserts that the result equals `original`.
    fn assert_round_trip(original: &WitnessMap<FieldElement>) {
        let serialized = original.serialize().expect("Serialization should succeed");
        let deserialized = WitnessMap::<FieldElement>::deserialize(&serialized)
            .expect("Deserialization should succeed");
        assert_eq!(original, &deserialized);
    }

    #[test]
    fn test_round_trip_serialization() {
        // A witness map with several (non-contiguous) entries.
        let mut original = WitnessMap::new();
        original.insert(Witness(0), FieldElement::from(42u128));
        original.insert(Witness(1), FieldElement::from(123u128));
        original.insert(Witness(5), FieldElement::from(999u128));
        original.insert(Witness(10), FieldElement::zero());
        original.insert(Witness(100), FieldElement::one());

        assert_round_trip(&original);
    }

    #[test]
    fn test_round_trip_empty_witness_map() {
        assert_round_trip(&WitnessMap::<FieldElement>::new());
    }

    #[test]
    fn test_round_trip_single_entry() {
        let mut original = WitnessMap::new();
        original.insert(Witness(0), FieldElement::from(12345u128));

        assert_round_trip(&original);
    }

    #[test]
    fn test_round_trip_large_field_elements() {
        let mut original = WitnessMap::new();
        original.insert(Witness(0), FieldElement::from(u128::MAX));
        original.insert(Witness(1), -FieldElement::one());
        original.insert(Witness(2), FieldElement::from(u128::MAX / 2));

        assert_round_trip(&original);
    }

    #[test]
    fn test_deserialize_rejects_non_gzip_input() {
        // Bytes without a gzip header must fail in the decompression stage,
        // not panic or decode into a bogus map.
        let err = WitnessMap::<FieldElement>::deserialize(b"this is not gzip data")
            .expect_err("Deserialization of non-gzip input should fail");
        assert!(
            err.to_string().starts_with("error decompressing witness map"),
            "unexpected error: {err}"
        );
    }

    #[test]
    fn test_into_iter_yields_entries_in_witness_order() {
        // Insertion order differs from key order on purpose.
        let mut map = WitnessMap::new();
        map.insert(Witness(7), FieldElement::from(7u128));
        map.insert(Witness(2), FieldElement::from(2u128));
        map.insert(Witness(5), FieldElement::from(5u128));

        let witnesses: Vec<Witness> = map.into_iter().map(|(witness, _)| witness).collect();
        assert_eq!(witnesses, vec![Witness(2), Witness(5), Witness(7)]);
    }
}