acir/native_types/
witness_map.rs1use std::{
2 collections::{BTreeMap, btree_map},
3 io::Read,
4 ops::Index,
5};
6
7use acir_field::AcirField;
8use flate2::Compression;
9use flate2::bufread::GzDecoder;
10use flate2::bufread::GzEncoder;
11use msgpack_tagged::MsgpackTagged;
12use serde::{Deserialize, Serialize};
13use thiserror::Error;
14
15use crate::{SerializationFormat, native_types::Witness, serialization};
16
17#[derive(Debug, Error)]
18enum SerializationError {
19 #[error("error compressing witness map: {0}")]
20 Compress(std::io::Error),
21
22 #[error("error decompressing witness map: {0}")]
23 Decompress(std::io::Error),
24
25 #[error("error serializing witness map: {0}")]
26 Serialize(std::io::Error),
27
28 #[error("error deserializing witness map: {0}")]
29 Deserialize(std::io::Error),
30}
31
32#[derive(Debug, Error)]
33#[error(transparent)]
34pub struct WitnessMapError(#[from] SerializationError);
35
36#[derive(Clone, Debug, PartialEq, Eq, Hash, PartialOrd, Ord, Default)]
38#[derive(Serialize, Deserialize, MsgpackTagged)]
39#[cfg_attr(feature = "arb", derive(proptest_derive::Arbitrary))]
40pub struct WitnessMap<F>(BTreeMap<Witness, F>);
41
42impl<F> WitnessMap<F> {
43 pub fn new() -> Self {
44 Self(BTreeMap::new())
45 }
46 pub fn get(&self, witness: &Witness) -> Option<&F> {
47 self.0.get(witness)
48 }
49 pub fn get_index(&self, index: u32) -> Option<&F> {
50 self.0.get(&index.into())
51 }
52 pub fn contains_key(&self, key: &Witness) -> bool {
53 self.0.contains_key(key)
54 }
55 pub fn insert(&mut self, key: Witness, value: F) -> Option<F> {
56 self.0.insert(key, value)
57 }
58 pub fn entry(&mut self, key: Witness) -> btree_map::Entry<'_, Witness, F> {
59 self.0.entry(key)
60 }
61}
62
63impl<F> Index<&Witness> for WitnessMap<F> {
64 type Output = F;
65
66 fn index(&self, index: &Witness) -> &Self::Output {
67 &self.0[index]
68 }
69}
70
71pub struct IntoIter<F>(btree_map::IntoIter<Witness, F>);
72
73impl<F> Iterator for IntoIter<F> {
74 type Item = (Witness, F);
75
76 fn next(&mut self) -> Option<Self::Item> {
77 self.0.next()
78 }
79}
80
81impl<F> IntoIterator for WitnessMap<F> {
82 type Item = (Witness, F);
83 type IntoIter = IntoIter<F>;
84
85 fn into_iter(self) -> Self::IntoIter {
86 IntoIter(self.0.into_iter())
87 }
88}
89
90impl<F> From<BTreeMap<Witness, F>> for WitnessMap<F> {
91 fn from(value: BTreeMap<Witness, F>) -> Self {
92 Self(value)
93 }
94}
95
96impl<F: AcirField + Serialize + MsgpackTagged> WitnessMap<F> {
97 pub fn serialize(&self) -> Result<Vec<u8>, WitnessMapError> {
99 let format = SerializationFormat::from_env()
100 .map_err(|err| SerializationError::Serialize(std::io::Error::other(err)))?;
101 self.serialize_with_format(format.unwrap_or_default())
102 }
103
104 pub fn serialize_with_format(
106 &self,
107 format: SerializationFormat,
108 ) -> Result<Vec<u8>, WitnessMapError> {
109 let buf = serialization::serialize_with_format(self, format)
110 .map_err(|e| WitnessMapError(SerializationError::Serialize(e)))?;
111
112 let mut deflater = GzEncoder::new(buf.as_slice(), Compression::best());
113 let mut buf = Vec::new();
114 deflater
115 .read_to_end(&mut buf)
116 .map_err(|e| WitnessMapError(SerializationError::Compress(e)))?;
117
118 Ok(buf)
119 }
120}
121
122impl<F: AcirField + for<'a> Deserialize<'a> + MsgpackTagged> WitnessMap<F> {
123 pub fn deserialize(buf: &[u8]) -> Result<Self, WitnessMapError> {
125 let mut deflater = GzDecoder::new(buf);
126 let mut buf = Vec::new();
127 deflater
128 .read_to_end(&mut buf)
129 .map_err(|e| WitnessMapError(SerializationError::Decompress(e)))?;
130
131 serialization::deserialize_any_format(&buf)
132 .map_err(|e| WitnessMapError(SerializationError::Deserialize(e)))
133 }
134}
135
#[cfg(test)]
mod tests {
    use super::*;
    use acir_field::FieldElement;

    /// Builds a map by inserting each `(index, value)` pair in order.
    fn witness_map(
        entries: impl IntoIterator<Item = (u32, FieldElement)>,
    ) -> WitnessMap<FieldElement> {
        let mut map = WitnessMap::new();
        for (index, value) in entries {
            map.insert(Witness(index), value);
        }
        map
    }

    /// Serializes `original`, deserializes the bytes again and checks that the
    /// result equals the input.
    fn assert_round_trip(original: &WitnessMap<FieldElement>) {
        let serialized = original.serialize().expect("Serialization should succeed");
        let deserialized =
            WitnessMap::deserialize(&serialized).expect("Deserialization should succeed");

        assert_eq!(original, &deserialized);
    }

    #[test]
    fn test_round_trip_serialization() {
        // Sparse, non-contiguous witness indices.
        let original = witness_map([
            (0, FieldElement::from(42u128)),
            (1, FieldElement::from(123u128)),
            (5, FieldElement::from(999u128)),
            (10, FieldElement::zero()),
            (100, FieldElement::one()),
        ]);

        assert_round_trip(&original);
    }

    #[test]
    fn test_round_trip_empty_witness_map() {
        let original = WitnessMap::<FieldElement>::new();

        assert_round_trip(&original);
    }

    #[test]
    fn test_round_trip_single_entry() {
        let original = witness_map([(0, FieldElement::from(12345u128))]);

        assert_round_trip(&original);
    }

    #[test]
    fn test_round_trip_large_field_elements() {
        let original = witness_map([
            (0, FieldElement::from(u128::MAX)),
            (1, -FieldElement::one()),
            (2, FieldElement::from(u128::MAX / 2)),
        ]);

        assert_round_trip(&original);
    }
}