acir/native_types/
witness_map.rs

1use std::{
2    collections::{BTreeMap, btree_map},
3    io::Read,
4    ops::Index,
5};
6
7use acir_field::AcirField;
8use flate2::Compression;
9use flate2::bufread::GzDecoder;
10use flate2::bufread::GzEncoder;
11use msgpack_tagged::MsgpackTagged;
12use serde::{Deserialize, Serialize};
13use thiserror::Error;
14
15use crate::{SerializationFormat, native_types::Witness, serialization};
16
17#[derive(Debug, Error)]
18enum SerializationError {
19    #[error("error compressing witness map: {0}")]
20    Compress(std::io::Error),
21
22    #[error("error decompressing witness map: {0}")]
23    Decompress(std::io::Error),
24
25    #[error("error serializing witness map: {0}")]
26    Serialize(std::io::Error),
27
28    #[error("error deserializing witness map: {0}")]
29    Deserialize(std::io::Error),
30}
31
32#[derive(Debug, Error)]
33#[error(transparent)]
34pub struct WitnessMapError(#[from] SerializationError);
35
/// A map from the witnesses in a constraint system to the field element values.
///
/// Backed by a [`BTreeMap`], so iteration visits entries in the key order defined by
/// [`Witness`]'s `Ord` implementation. The serde derives below describe the plain
/// encoding. [`WitnessMap::serialize`] and [`WitnessMap::deserialize`] wrap that
/// encoding in gzip compression.
#[derive(Clone, Debug, PartialEq, Eq, Hash, PartialOrd, Ord, Default)]
#[derive(Serialize, Deserialize, MsgpackTagged)]
// Property-test generation is only compiled in with the `arb` feature.
#[cfg_attr(feature = "arb", derive(proptest_derive::Arbitrary))]
pub struct WitnessMap<F>(BTreeMap<Witness, F>);
41
42impl<F> WitnessMap<F> {
43    pub fn new() -> Self {
44        Self(BTreeMap::new())
45    }
46    pub fn get(&self, witness: &Witness) -> Option<&F> {
47        self.0.get(witness)
48    }
49    pub fn get_index(&self, index: u32) -> Option<&F> {
50        self.0.get(&index.into())
51    }
52    pub fn contains_key(&self, key: &Witness) -> bool {
53        self.0.contains_key(key)
54    }
55    pub fn insert(&mut self, key: Witness, value: F) -> Option<F> {
56        self.0.insert(key, value)
57    }
58    pub fn entry(&mut self, key: Witness) -> btree_map::Entry<'_, Witness, F> {
59        self.0.entry(key)
60    }
61}
62
63impl<F> Index<&Witness> for WitnessMap<F> {
64    type Output = F;
65
66    fn index(&self, index: &Witness) -> &Self::Output {
67        &self.0[index]
68    }
69}
70
71pub struct IntoIter<F>(btree_map::IntoIter<Witness, F>);
72
73impl<F> Iterator for IntoIter<F> {
74    type Item = (Witness, F);
75
76    fn next(&mut self) -> Option<Self::Item> {
77        self.0.next()
78    }
79}
80
81impl<F> IntoIterator for WitnessMap<F> {
82    type Item = (Witness, F);
83    type IntoIter = IntoIter<F>;
84
85    fn into_iter(self) -> Self::IntoIter {
86        IntoIter(self.0.into_iter())
87    }
88}
89
90impl<F> From<BTreeMap<Witness, F>> for WitnessMap<F> {
91    fn from(value: BTreeMap<Witness, F>) -> Self {
92        Self(value)
93    }
94}
95
96impl<F: AcirField + Serialize + MsgpackTagged> WitnessMap<F> {
97    /// Serialize and compress.
98    pub fn serialize(&self) -> Result<Vec<u8>, WitnessMapError> {
99        let format = SerializationFormat::from_env()
100            .map_err(|err| SerializationError::Serialize(std::io::Error::other(err)))?;
101        self.serialize_with_format(format.unwrap_or_default())
102    }
103
104    /// Serialize and compress with a given format.
105    pub fn serialize_with_format(
106        &self,
107        format: SerializationFormat,
108    ) -> Result<Vec<u8>, WitnessMapError> {
109        let buf = serialization::serialize_with_format(self, format)
110            .map_err(|e| WitnessMapError(SerializationError::Serialize(e)))?;
111
112        let mut deflater = GzEncoder::new(buf.as_slice(), Compression::best());
113        let mut buf = Vec::new();
114        deflater
115            .read_to_end(&mut buf)
116            .map_err(|e| WitnessMapError(SerializationError::Compress(e)))?;
117
118        Ok(buf)
119    }
120}
121
122impl<F: AcirField + for<'a> Deserialize<'a> + MsgpackTagged> WitnessMap<F> {
123    /// Decompress and deserialize.
124    pub fn deserialize(buf: &[u8]) -> Result<Self, WitnessMapError> {
125        let mut deflater = GzDecoder::new(buf);
126        let mut buf = Vec::new();
127        deflater
128            .read_to_end(&mut buf)
129            .map_err(|e| WitnessMapError(SerializationError::Decompress(e)))?;
130
131        serialization::deserialize_any_format(&buf)
132            .map_err(|e| WitnessMapError(SerializationError::Deserialize(e)))
133    }
134}
135
#[cfg(test)]
mod tests {
    use super::*;
    use acir_field::FieldElement;

    /// Serializes `original` with the environment-selected format, deserializes the bytes
    /// and asserts that the same map comes back.
    fn assert_round_trip(original: &WitnessMap<FieldElement>) {
        let serialized = original.serialize().expect("Serialization should succeed");
        let deserialized = WitnessMap::<FieldElement>::deserialize(&serialized)
            .expect("Deserialization should succeed");
        assert_eq!(*original, deserialized);
    }

    #[test]
    fn test_round_trip_serialization() {
        // Several entries, including non-contiguous witness indices.
        let mut original = WitnessMap::new();
        original.insert(Witness(0), FieldElement::from(42u128));
        original.insert(Witness(1), FieldElement::from(123u128));
        original.insert(Witness(5), FieldElement::from(999u128));
        original.insert(Witness(10), FieldElement::zero());
        original.insert(Witness(100), FieldElement::one());

        assert_round_trip(&original);
    }

    #[test]
    fn test_round_trip_empty_witness_map() {
        assert_round_trip(&WitnessMap::new());
    }

    #[test]
    fn test_round_trip_single_entry() {
        let mut original = WitnessMap::new();
        original.insert(Witness(0), FieldElement::from(12345u128));

        assert_round_trip(&original);
    }

    #[test]
    fn test_round_trip_large_field_elements() {
        let mut original = WitnessMap::new();
        original.insert(Witness(0), FieldElement::from(u128::MAX));
        original.insert(Witness(1), -FieldElement::one());
        original.insert(Witness(2), FieldElement::from(u128::MAX / 2));

        assert_round_trip(&original);
    }

    #[test]
    fn test_round_trip_with_explicit_default_format() {
        // Unlike `serialize`, this path does not consult the environment, so the format
        // under test is fixed regardless of how the test process is configured.
        let mut original = WitnessMap::new();
        original.insert(Witness(3), FieldElement::from(7u128));
        original.insert(Witness(4), -FieldElement::one());

        let serialized = original
            .serialize_with_format(SerializationFormat::default())
            .expect("Serialization should succeed");
        let deserialized = WitnessMap::<FieldElement>::deserialize(&serialized)
            .expect("Deserialization should succeed");

        assert_eq!(original, deserialized);
    }

    #[test]
    fn test_deserialize_rejects_non_gzip_input() {
        // Bytes without the gzip magic header must surface as a decompression error.
        let err = WitnessMap::<FieldElement>::deserialize(b"definitely not gzip")
            .expect_err("Deserialization of non-gzip bytes should fail");
        assert!(
            err.to_string().starts_with("error decompressing witness map"),
            "unexpected error message: {err}"
        );
    }
}