1use num_enum::{IntoPrimitive, TryFromPrimitive};
4use serde::{Deserialize, Serialize};
5use std::str::FromStr;
6use strum_macros::EnumString;
7
/// Name of the environment variable consulted by `Format::from_env` to select a
/// serialization [`Format`] by its kebab-case name (e.g. `msgpack-compact`).
const FORMAT_ENV_VAR: &str = "NOIR_SERIALIZATION_FORMAT";
9
/// Identifies the wire format of a serialized buffer.
///
/// The `u8` discriminant is written as the first byte of the buffer produced by
/// `serialize_with_format`. `deserialize_any_format` reads that byte back, and it rejects
/// any byte that does not map to a variant here.
///
/// Variants parse from kebab-case names via `FromStr` (`"msgpack"`, `"msgpack-compact"`),
/// which is how `Format::from_env` interprets the environment variable.
#[derive(
    Debug, Default, Clone, Copy, IntoPrimitive, TryFromPrimitive, EnumString, PartialEq, Eq,
)]
#[strum(serialize_all = "kebab-case")]
#[repr(u8)]
pub enum Format {
    /// MessagePack with structs encoded as maps keyed by field name
    /// (`msgpack_serialize(.., compact = false)`).
    Msgpack = 2,
    /// MessagePack with structs encoded positionally as tuples
    /// (`msgpack_serialize(.., compact = true)`). This is the `Default` format.
    #[default]
    MsgpackCompact = 3,
}
23
24impl Format {
25 pub fn from_env() -> Result<Option<Self>, String> {
32 let Ok(format) = std::env::var(FORMAT_ENV_VAR) else {
33 return Ok(None);
34 };
35 Self::from_str(&format)
36 .map(Some)
37 .map_err(|e| format!("unknown format '{format}' in {FORMAT_ENV_VAR}: {e}"))
38 }
39}
40
41pub(crate) fn msgpack_serialize<T: Serialize>(
51 value: &T,
52 compact: bool,
53) -> std::io::Result<Vec<u8>> {
54 let mut buf = Vec::new();
60
61 let serializer = rmp_serde::Serializer::new(&mut buf)
62 .with_bytes(rmp_serde::config::BytesMode::ForceIterables);
63
64 let result = if compact {
65 value.serialize(&mut serializer.with_struct_tuple())
66 } else {
67 value.serialize(&mut serializer.with_struct_map())
68 };
69
70 match result {
71 Ok(()) => Ok(buf),
72 Err(e) => Err(std::io::Error::other(e)),
73 }
74}
75
76pub(crate) fn msgpack_deserialize<T: for<'a> Deserialize<'a>>(buf: &[u8]) -> std::io::Result<T> {
78 rmp_serde::from_slice(buf).map_err(|e| std::io::Error::new(std::io::ErrorKind::InvalidInput, e))
79}
80
81pub(crate) fn deserialize_any_format<T>(buf: &[u8]) -> std::io::Result<T>
83where
84 T: for<'a> Deserialize<'a>,
85{
86 let Some(format_byte) = buf.first() else {
87 return Err(std::io::Error::new(std::io::ErrorKind::InvalidInput, "empty buffer"));
88 };
89
90 match Format::try_from(*format_byte) {
91 Ok(Format::Msgpack) | Ok(Format::MsgpackCompact) => msgpack_deserialize(&buf[1..]),
92 Err(msg) => Err(std::io::Error::new(std::io::ErrorKind::InvalidInput, msg.to_string())),
93 }
94}
95
96pub(crate) fn serialize_with_format<T>(value: &T, format: Format) -> std::io::Result<Vec<u8>>
97where
98 T: Serialize,
99{
100 let mut buf = match format {
102 Format::Msgpack => msgpack_serialize(value, false)?,
103 Format::MsgpackCompact => msgpack_serialize(value, true)?,
104 };
105 let mut res = vec![format.into()];
106 res.append(&mut buf);
107 Ok(res)
108}
109
#[cfg(test)]
mod tests {
    use brillig::{BitSize, HeapArray, IntegerBitSize, ValueOrArray, lengths::SemiFlattenedLength};
    use rmpv::Value;
    use std::str::FromStr;

    use crate::{
        native_types::Witness,
        serialization::{
            Format, deserialize_any_format, msgpack_deserialize, msgpack_serialize,
            serialize_with_format,
        },
    };

    /// The "old" schema that previously serialized data is expected to decode into.
    mod version1 {
        use serde::{Deserialize, Serialize};

        #[derive(Debug, Serialize, Deserialize, PartialEq, Eq)]
        pub(crate) enum Foo {
            Case0 { d: u32 },
            Case1 { a: u64, b: bool },
            Case2 { a: i32 },
            Case3 { a: bool },
            Case4 { a: Box<Foo> },
            Case5 { a: u32, b: Option<u32> },
        }
    }

    /// The "new" schema, which evolves `version1` in several ways.
    mod version2 {
        use serde::{Deserialize, Serialize};

        #[derive(Debug, Serialize, Deserialize, PartialEq, Eq)]
        pub(crate) enum Foo {
            // `Case0` was removed.
            // Unchanged fields.
            Case1 {
                a: u64,
                b: bool,
            },
            // New field `b` inserted before the existing `a`.
            Case2 {
                b: String,
                a: i32,
            },
            // New field `b` appended after the existing `a`.
            Case3 {
                a: bool,
                b: String,
            },
            // Declared before `Case4` and with the optional field `b` removed.
            Case5 {
                a: u32,
            },
            // Field renamed in Rust (`c`) but still serialized under the old name `a`.
            Case4 {
                #[serde(rename = "a")]
                c: Box<Foo>,
            },
            // Brand-new variants.
            Case6 {
                b: i64,
            },
            Case7 {
                c: bool,
            },
        }
    }

    /// With field-name (map) encoding, `version2` data decodes as `version1`: extra fields
    /// are skipped, serde-renamed fields match by name, and a missing `Option` field
    /// becomes `None`.
    #[test]
    fn msgpack_serialize_backwards_compatibility() {
        let cases = vec![
            (version2::Foo::Case1 { b: true, a: 1 }, version1::Foo::Case1 { b: true, a: 1 }),
            (version2::Foo::Case2 { b: "prefix".into(), a: 2 }, version1::Foo::Case2 { a: 2 }),
            (
                version2::Foo::Case3 { a: true, b: "suffix".into() },
                version1::Foo::Case3 { a: true },
            ),
            (
                version2::Foo::Case4 { c: Box::new(version2::Foo::Case1 { a: 4, b: false }) },
                version1::Foo::Case4 { a: Box::new(version1::Foo::Case1 { a: 4, b: false }) },
            ),
            (version2::Foo::Case5 { a: 5 }, version1::Foo::Case5 { a: 5, b: None }),
        ];

        for (i, (v2, v1)) in cases.into_iter().enumerate() {
            let bz = msgpack_serialize(&v2, false).unwrap();
            let v = msgpack_deserialize::<version1::Foo>(&bz)
                .unwrap_or_else(|e| panic!("case {i} failed: {e}"));
            assert_eq!(v, v1);
        }
    }

    /// With positional (tuple) encoding, the same schema changes are not tolerated.
    /// Inserted, appended, or removed fields make decoding fail. Only cases whose field
    /// layout is unchanged still round-trip.
    #[test]
    fn msgpack_serialize_compact_backwards_compatibility() {
        let cases = vec![
            (version2::Foo::Case1 { b: true, a: 1 }, version1::Foo::Case1 { b: true, a: 1 }, None),
            (
                version2::Foo::Case2 { b: "prefix".into(), a: 2 },
                version1::Foo::Case2 { a: 2 },
                Some("wrong msgpack marker FixStr(6)"),
            ),
            (
                version2::Foo::Case3 { a: true, b: "suffix".into() },
                version1::Foo::Case3 { a: true },
                Some("array had incorrect length, expected 1"),
            ),
            (
                version2::Foo::Case4 { c: Box::new(version2::Foo::Case1 { a: 4, b: false }) },
                version1::Foo::Case4 { a: Box::new(version1::Foo::Case1 { a: 4, b: false }) },
                None,
            ),
            (
                version2::Foo::Case5 { a: 5 },
                version1::Foo::Case5 { a: 5, b: None },
                Some("invalid length 1, expected struct variant Foo::Case5 with 2 elements"),
            ),
        ];

        for (i, (v2, v1, ex)) in cases.into_iter().enumerate() {
            let bz = msgpack_serialize(&v2, true).unwrap();
            let res = msgpack_deserialize::<version1::Foo>(&bz);
            match (res, ex) {
                (Ok(v), None) => {
                    assert_eq!(v, v1);
                }
                (Ok(_), Some(ex)) => panic!("case {i} expected to fail with {ex}"),
                (Err(e), None) => panic!("case {i} expected to pass; got {e}"),
                (Err(e), Some(ex)) => {
                    let e = e.to_string();
                    if !e.contains(ex) {
                        panic!("case {i} error expected to contain {ex}; got {e}")
                    }
                }
            }
        }
    }

    /// Non-compact encoding represents a struct-carrying enum variant as a single-entry map
    /// keyed by the variant name.
    #[test]
    fn msgpack_repr_enum_of_structs() {
        let value = ValueOrArray::HeapArray(HeapArray {
            pointer: brillig::MemoryAddress::Relative(0),
            size: SemiFlattenedLength(3),
        });
        let bz = msgpack_serialize(&value, false).unwrap();
        let msg = rmpv::decode::read_value::<&[u8]>(&mut bz.as_ref()).unwrap();
        let Value::Map(fields) = msg else {
            panic!("expected Map: {msg:?}");
        };
        assert_eq!(fields.len(), 1);
        let Value::String(key) = &fields[0].0 else {
            panic!("expected String key: {fields:?}");
        };
        assert_eq!(key.as_str(), Some("HeapArray"));
    }

    /// A unit variant is encoded as just its name.
    #[test]
    fn msgpack_repr_enum_of_unit_structs() {
        let value = IntegerBitSize::U1;
        let bz = msgpack_serialize(&value, false).unwrap();
        let msg = rmpv::decode::read_value::<&[u8]>(&mut bz.as_ref()).unwrap();
        assert_eq!(msg.as_str(), Some("U1"));
    }

    /// Unit and newtype variants of the same enum are encoded as a name and a
    /// single-entry map, respectively.
    #[test]
    fn msgpack_repr_enum_of_mixed() {
        let value = vec![BitSize::Field, BitSize::Integer(IntegerBitSize::U64)];
        let bz = msgpack_serialize(&value, false).unwrap();
        let msg = rmpv::decode::read_value::<&[u8]>(&mut bz.as_ref()).unwrap();
        assert_eq!(format!("{msg}"), r#"["Field", {"Integer": "U64"}]"#);
    }

    /// A newtype struct is encoded transparently as its inner value.
    #[test]
    fn msgpack_repr_newtype() {
        let value = Witness(1);
        let bz = msgpack_serialize(&value, false).unwrap();
        let msg = rmpv::decode::read_value::<&[u8]>(&mut bz.as_ref()).unwrap();
        assert!(matches!(msg, Value::Integer(_)));
    }

    /// Every format parses from its kebab-case name, and unknown names are rejected.
    #[test]
    fn format_from_str() {
        assert_eq!(Format::from_str("msgpack").unwrap(), Format::Msgpack);
        assert_eq!(Format::from_str("msgpack-compact").unwrap(), Format::MsgpackCompact);
        assert!(Format::from_str("json").is_err());
    }

    /// `serialize_with_format` prefixes the format tag, and `deserialize_any_format`
    /// restores the original value for every format.
    #[test]
    fn serialize_with_format_roundtrip() {
        for format in [Format::Msgpack, Format::MsgpackCompact] {
            let value = version1::Foo::Case5 { a: 7, b: Some(9) };
            let bz = serialize_with_format(&value, format).unwrap();
            assert_eq!(bz.first().copied(), Some(u8::from(format)));
            let back: version1::Foo = deserialize_any_format(&bz)
                .unwrap_or_else(|e| panic!("{format:?} failed to round-trip: {e}"));
            assert_eq!(back, value);
        }
    }

    /// Empty buffers and unknown tag bytes are reported as `InvalidInput`.
    #[test]
    fn deserialize_any_format_rejects_bad_input() {
        let err = deserialize_any_format::<version1::Foo>(&[]).unwrap_err();
        assert_eq!(err.kind(), std::io::ErrorKind::InvalidInput);

        let err = deserialize_any_format::<version1::Foo>(&[0xff]).unwrap_err();
        assert_eq!(err.kind(), std::io::ErrorKind::InvalidInput);
    }
}