acir/native_types/
witness_map.rs1use std::{
2 collections::{BTreeMap, btree_map},
3 io::Read,
4 ops::Index,
5};
6
7use acir_field::AcirField;
8use flate2::Compression;
9use flate2::bufread::GzDecoder;
10use flate2::bufread::GzEncoder;
11use serde::{Deserialize, Serialize};
12use thiserror::Error;
13
14use crate::{SerializationFormat, native_types::Witness, serialization};
15
16#[derive(Debug, Error)]
17enum SerializationError {
18 #[error("error compressing witness map: {0}")]
19 Compress(std::io::Error),
20
21 #[error("error decompressing witness map: {0}")]
22 Decompress(std::io::Error),
23
24 #[error("error serializing witness map: {0}")]
25 Serialize(std::io::Error),
26
27 #[error("error deserializing witness map: {0}")]
28 Deserialize(std::io::Error),
29}
30
31#[derive(Debug, Error)]
32#[error(transparent)]
33pub struct WitnessMapError(#[from] SerializationError);
34
35#[derive(Clone, Debug, PartialEq, Eq, Hash, PartialOrd, Ord, Default, Serialize, Deserialize)]
37#[cfg_attr(feature = "arb", derive(proptest_derive::Arbitrary))]
38pub struct WitnessMap<F>(BTreeMap<Witness, F>);
39
40impl<F> WitnessMap<F> {
41 pub fn new() -> Self {
42 Self(BTreeMap::new())
43 }
44 pub fn get(&self, witness: &Witness) -> Option<&F> {
45 self.0.get(witness)
46 }
47 pub fn get_index(&self, index: u32) -> Option<&F> {
48 self.0.get(&index.into())
49 }
50 pub fn contains_key(&self, key: &Witness) -> bool {
51 self.0.contains_key(key)
52 }
53 pub fn insert(&mut self, key: Witness, value: F) -> Option<F> {
54 self.0.insert(key, value)
55 }
56 pub fn entry(&mut self, key: Witness) -> btree_map::Entry<'_, Witness, F> {
57 self.0.entry(key)
58 }
59}
60
61impl<F> Index<&Witness> for WitnessMap<F> {
62 type Output = F;
63
64 fn index(&self, index: &Witness) -> &Self::Output {
65 &self.0[index]
66 }
67}
68
69pub struct IntoIter<F>(btree_map::IntoIter<Witness, F>);
70
71impl<F> Iterator for IntoIter<F> {
72 type Item = (Witness, F);
73
74 fn next(&mut self) -> Option<Self::Item> {
75 self.0.next()
76 }
77}
78
79impl<F> IntoIterator for WitnessMap<F> {
80 type Item = (Witness, F);
81 type IntoIter = IntoIter<F>;
82
83 fn into_iter(self) -> Self::IntoIter {
84 IntoIter(self.0.into_iter())
85 }
86}
87
88impl<F> From<BTreeMap<Witness, F>> for WitnessMap<F> {
89 fn from(value: BTreeMap<Witness, F>) -> Self {
90 Self(value)
91 }
92}
93
94impl<F: AcirField + Serialize> WitnessMap<F> {
95 pub fn serialize(&self) -> Result<Vec<u8>, WitnessMapError> {
97 let format = SerializationFormat::from_env()
98 .map_err(|err| SerializationError::Serialize(std::io::Error::other(err)))?;
99 self.serialize_with_format(format.unwrap_or_default())
100 }
101
102 pub fn serialize_with_format(
104 &self,
105 format: SerializationFormat,
106 ) -> Result<Vec<u8>, WitnessMapError> {
107 let buf = serialization::serialize_with_format(self, format)
108 .map_err(|e| WitnessMapError(SerializationError::Serialize(e)))?;
109
110 let mut deflater = GzEncoder::new(buf.as_slice(), Compression::best());
111 let mut buf = Vec::new();
112 deflater
113 .read_to_end(&mut buf)
114 .map_err(|e| WitnessMapError(SerializationError::Compress(e)))?;
115
116 Ok(buf)
117 }
118}
119
120impl<F: AcirField + for<'a> Deserialize<'a>> WitnessMap<F> {
121 pub fn deserialize(buf: &[u8]) -> Result<Self, WitnessMapError> {
123 let mut deflater = GzDecoder::new(buf);
124 let mut buf = Vec::new();
125 deflater
126 .read_to_end(&mut buf)
127 .map_err(|e| WitnessMapError(SerializationError::Decompress(e)))?;
128
129 serialization::deserialize_any_format(&buf)
130 .map_err(|e| WitnessMapError(SerializationError::Deserialize(e)))
131 }
132}
133
#[cfg(test)]
mod tests {
    use super::*;
    use acir_field::FieldElement;

    /// Serializes `original` with the environment-selected format, deserializes the
    /// bytes again, and asserts the result equals the input.
    fn assert_round_trip(original: &WitnessMap<FieldElement>) {
        let serialized = original.serialize().expect("Serialization should succeed");
        let deserialized =
            WitnessMap::deserialize(&serialized).expect("Deserialization should succeed");
        assert_eq!(original, &deserialized);
    }

    #[test]
    fn test_round_trip_serialization() {
        let mut original = WitnessMap::new();
        original.insert(Witness(0), FieldElement::from(42u128));
        original.insert(Witness(1), FieldElement::from(123u128));
        original.insert(Witness(5), FieldElement::from(999u128));
        original.insert(Witness(10), FieldElement::zero());
        original.insert(Witness(100), FieldElement::one());

        assert_round_trip(&original);
    }

    #[test]
    fn test_round_trip_empty_witness_map() {
        let original = WitnessMap::<FieldElement>::new();

        assert_round_trip(&original);
    }

    #[test]
    fn test_round_trip_single_entry() {
        let mut original = WitnessMap::new();
        original.insert(Witness(0), FieldElement::from(12345u128));

        assert_round_trip(&original);
    }

    #[test]
    fn test_round_trip_large_field_elements() {
        let mut original = WitnessMap::new();
        original.insert(Witness(0), FieldElement::from(u128::MAX));
        original.insert(Witness(1), -FieldElement::one());
        original.insert(Witness(2), FieldElement::from(u128::MAX / 2));

        assert_round_trip(&original);
    }

    #[test]
    fn test_round_trip_with_explicit_default_format() {
        // `serialize_with_format` bypasses the environment, so this test does not
        // depend on how the format environment variable is set.
        let mut original = WitnessMap::new();
        original.insert(Witness(3), FieldElement::from(7u128));
        original.insert(Witness(8), FieldElement::one());

        let serialized = original
            .serialize_with_format(SerializationFormat::default())
            .expect("Serialization should succeed");
        let deserialized =
            WitnessMap::deserialize(&serialized).expect("Deserialization should succeed");

        assert_eq!(original, deserialized);
    }

    #[test]
    fn test_deserialize_rejects_non_gzip_input() {
        // The input does not start with the gzip magic bytes, so decompression
        // must fail before any deserialization is attempted.
        let err = WitnessMap::<FieldElement>::deserialize(b"definitely not gzip")
            .expect_err("Non-gzip input should fail to deserialize");

        assert!(
            err.to_string().starts_with("error decompressing witness map"),
            "unexpected error: {err}"
        );
    }
}