1use msgpack_tagged::{
4 EncodingStrategy, MsgpackTagged, Serializer as TaggedSerializer, TagRegistry,
5 msgpack_tagged_deserialize,
6};
7use num_enum::{IntoPrimitive, TryFromPrimitive};
8use serde::{Deserialize, Serialize};
9use std::str::FromStr;
10use strum_macros::EnumString;
11
12const FORMAT_ENV_VAR: &str = "NOIR_SERIALIZATION_FORMAT";
13
14#[derive(Debug, Default, Clone, Copy, IntoPrimitive, TryFromPrimitive, EnumString, PartialEq, Eq)]
16#[cfg_attr(feature = "arb", derive(proptest_derive::Arbitrary))]
17#[strum(serialize_all = "kebab-case")]
18#[repr(u8)]
19pub enum Format {
20 Msgpack = 2,
22 #[default]
24 MsgpackCompact = 3,
25 MsgpackTagged = 4,
30}
31
32impl Format {
33 pub fn from_env() -> Result<Option<Self>, String> {
40 let Ok(format) = std::env::var(FORMAT_ENV_VAR) else {
41 return Ok(None);
42 };
43 Self::from_str(&format)
44 .map(Some)
45 .map_err(|e| format!("unknown format '{format}' in {FORMAT_ENV_VAR}: {e}"))
46 }
47}
48
49pub(crate) fn msgpack_serialize<T: Serialize>(
59 value: &T,
60 compact: bool,
61) -> std::io::Result<Vec<u8>> {
62 let mut buf = Vec::new();
68
69 let serializer = rmp_serde::Serializer::new(&mut buf)
70 .with_bytes(rmp_serde::config::BytesMode::ForceIterables);
71
72 let result = if compact {
73 value.serialize(&mut serializer.with_struct_tuple())
74 } else {
75 value.serialize(&mut serializer.with_struct_map())
76 };
77
78 match result {
79 Ok(()) => Ok(buf),
80 Err(e) => Err(std::io::Error::other(e)),
81 }
82}
83
84pub(crate) fn msgpack_deserialize<T: for<'a> Deserialize<'a>>(buf: &[u8]) -> std::io::Result<T> {
86 rmp_serde::from_slice(buf).map_err(|e| std::io::Error::new(std::io::ErrorKind::InvalidInput, e))
87}
88
89pub(crate) fn deserialize_any_format<T>(buf: &[u8]) -> std::io::Result<T>
91where
92 T: for<'a> Deserialize<'a> + MsgpackTagged,
93{
94 let Some(format_byte) = buf.first() else {
95 return Err(std::io::Error::new(std::io::ErrorKind::InvalidInput, "empty buffer"));
96 };
97
98 match Format::try_from(*format_byte) {
99 Ok(Format::Msgpack) | Ok(Format::MsgpackCompact) => msgpack_deserialize(&buf[1..]),
100 Ok(Format::MsgpackTagged) => msgpack_tagged_deserialize(&buf[1..]),
101 Err(msg) => Err(std::io::Error::new(std::io::ErrorKind::InvalidInput, msg.to_string())),
102 }
103}
104
105pub(crate) fn serialize_with_format<T>(value: &T, format: Format) -> std::io::Result<Vec<u8>>
106where
107 T: Serialize + MsgpackTagged,
108{
109 let mut buf = match format {
111 Format::Msgpack => msgpack_serialize(value, false)?,
112 Format::MsgpackCompact => msgpack_serialize(value, true)?,
113 Format::MsgpackTagged => msgpack_tagged_serialize_acir(value)?,
114 };
115 let mut res = vec![format.into()];
116 res.append(&mut buf);
117 Ok(res)
118}
119
120pub(crate) fn msgpack_tagged_serialize_acir<T>(value: &T) -> std::io::Result<Vec<u8>>
139where
140 T: ?Sized + Serialize + MsgpackTagged,
141{
142 let registry = TagRegistry::from_type::<T>();
143 let mut buf = Vec::new();
144 let mut serializer = TaggedSerializer::new(&mut buf, ®istry)
145 .with_default_strategy(EncodingStrategy::Array)
146 .with_strategy_for_name("Program", EncodingStrategy::Tagged)
147 .with_strategy_for_name("Circuit", EncodingStrategy::Tagged)
148 .with_strategy_for_name("BrilligBytecode", EncodingStrategy::Tagged);
149 value.serialize(&mut serializer).map_err(std::io::Error::other)?;
150 Ok(buf)
151}
152
#[cfg(test)]
mod tests {
    use brillig::{BitSize, HeapArray, IntegerBitSize, ValueOrArray, lengths::SemiFlattenedLength};
    use std::str::FromStr;

    use crate::{
        native_types::Witness,
        serialization::{Format, msgpack_deserialize, msgpack_serialize},
    };

    /// The "old" schema: what a reader built against an earlier version expects.
    mod version1 {
        use serde::{Deserialize, Serialize};

        #[derive(Debug, Serialize, Deserialize, PartialEq, Eq)]
        pub(crate) enum Foo {
            Case0 { d: u32 },
            Case1 { a: u64, b: bool },
            Case2 { a: i32 },
            Case3 { a: bool },
            Case4 { a: Box<Foo> },
            Case5 { a: u32, b: Option<u32> },
        }
    }

    /// The "new" schema: what a writer built against a later version emits.
    mod version2 {
        use serde::{Deserialize, Serialize};

        #[derive(Debug, Serialize, Deserialize, PartialEq, Eq)]
        pub(crate) enum Foo {
            // Unchanged from version1 (`Case0` was removed).
            Case1 { a: u64, b: bool },
            // Gains a new leading field.
            Case2 { b: String, a: i32 },
            // Gains a new trailing field.
            Case3 { a: bool, b: String },
            // Loses its optional field and moves before `Case4`.
            Case5 { a: u32 },
            Case4 {
                // Renamed in Rust, but kept as `a` on the wire.
                #[serde(rename = "a")]
                c: Box<Foo>,
            },
            // New variants unknown to version1.
            Case6 { b: i64 },
            Case7 { c: bool },
        }
    }

    /// With field names on the wire (`compact = false`), an old reader decodes every
    /// shared variant. This holds when extra fields are present, when an `Option`
    /// field is missing (it decodes as `None`), and when variants have been reordered.
    #[test]
    fn msgpack_serialize_backwards_compatibility() {
        let cases = vec![
            (version2::Foo::Case1 { b: true, a: 1 }, version1::Foo::Case1 { b: true, a: 1 }),
            (version2::Foo::Case2 { b: "prefix".into(), a: 2 }, version1::Foo::Case2 { a: 2 }),
            (
                version2::Foo::Case3 { a: true, b: "suffix".into() },
                version1::Foo::Case3 { a: true },
            ),
            (
                version2::Foo::Case4 { c: Box::new(version2::Foo::Case1 { a: 4, b: false }) },
                version1::Foo::Case4 { a: Box::new(version1::Foo::Case1 { a: 4, b: false }) },
            ),
            (version2::Foo::Case5 { a: 5 }, version1::Foo::Case5 { a: 5, b: None }),
        ];

        for (i, (v2, v1)) in cases.into_iter().enumerate() {
            let bz = msgpack_serialize(&v2, false).unwrap();
            let v = msgpack_deserialize::<version1::Foo>(&bz)
                .unwrap_or_else(|e| panic!("case {i} failed: {e}"));
            assert_eq!(v, v1);
        }
    }

    /// With positional encoding (`compact = true`), only variants whose field list is
    /// positionally unchanged stay compatible. Every other case must fail with a
    /// decoder error containing the listed text.
    #[test]
    fn msgpack_serialize_compact_backwards_compatibility() {
        let cases = vec![
            (version2::Foo::Case1 { b: true, a: 1 }, version1::Foo::Case1 { b: true, a: 1 }, None),
            (
                version2::Foo::Case2 { b: "prefix".into(), a: 2 },
                version1::Foo::Case2 { a: 2 },
                Some("wrong msgpack marker FixStr(6)"),
            ),
            (
                version2::Foo::Case3 { a: true, b: "suffix".into() },
                version1::Foo::Case3 { a: true },
                Some("array had incorrect length, expected 1"),
            ),
            (
                version2::Foo::Case4 { c: Box::new(version2::Foo::Case1 { a: 4, b: false }) },
                version1::Foo::Case4 { a: Box::new(version1::Foo::Case1 { a: 4, b: false }) },
                None,
            ),
            (
                version2::Foo::Case5 { a: 5 },
                version1::Foo::Case5 { a: 5, b: None },
                Some("invalid length 1, expected struct variant Foo::Case5 with 2 elements"),
            ),
        ];

        for (i, (v2, v1, ex)) in cases.into_iter().enumerate() {
            let bz = msgpack_serialize(&v2, true).unwrap();
            let res = msgpack_deserialize::<version1::Foo>(&bz);
            match (res, ex) {
                (Ok(v), None) => {
                    assert_eq!(v, v1);
                }
                (Ok(_), Some(ex)) => panic!("case {i} expected to fail with {ex}"),
                (Err(e), None) => panic!("case {i} expected to pass; got {e}"),
                (Err(e), Some(ex)) => {
                    let e = e.to_string();
                    if !e.contains(ex) {
                        panic!("case {i} error expected to contain {ex}; got {e}")
                    }
                }
            }
        }
    }

    /// A struct-carrying enum variant encodes as a single-entry map keyed by the variant name.
    #[test]
    fn msgpack_repr_enum_of_structs() {
        use rmpv::Value;
        let value = ValueOrArray::HeapArray(HeapArray {
            pointer: brillig::MemoryAddress::Relative(0),
            size: SemiFlattenedLength(3),
        });
        let bz = msgpack_serialize(&value, false).unwrap();
        let msg = rmpv::decode::read_value::<&[u8]>(&mut bz.as_ref()).unwrap();
        let Value::Map(fields) = msg else {
            panic!("expected Map: {msg:?}");
        };
        assert_eq!(fields.len(), 1);
        let Value::String(key) = &fields[0].0 else {
            panic!("expected String key: {fields:?}");
        };
        assert_eq!(key.as_str(), Some("HeapArray"));
    }

    /// A unit variant encodes as its name string.
    #[test]
    fn msgpack_repr_enum_of_unit_structs() {
        let value = IntegerBitSize::U1;
        let bz = msgpack_serialize(&value, false).unwrap();
        let msg = rmpv::decode::read_value::<&[u8]>(&mut bz.as_ref()).unwrap();
        assert_eq!(msg.as_str(), Some("U1"));
    }

    /// In a mixed enum, unit variants are strings and newtype variants are single-entry maps.
    #[test]
    fn msgpack_repr_enum_of_mixed() {
        let value = vec![BitSize::Field, BitSize::Integer(IntegerBitSize::U64)];
        let bz = msgpack_serialize(&value, false).unwrap();
        let msg = rmpv::decode::read_value::<&[u8]>(&mut bz.as_ref()).unwrap();
        assert_eq!(format!("{msg}"), r#"["Field", {"Integer": "U64"}]"#);
    }

    /// A newtype struct encodes as its inner value rather than as a wrapper.
    #[test]
    fn msgpack_repr_newtype() {
        use rmpv::Value;
        let value = Witness(1);
        let bz = msgpack_serialize(&value, false).unwrap();
        let msg = rmpv::decode::read_value::<&[u8]>(&mut bz.as_ref()).unwrap();
        assert!(matches!(msg, Value::Integer(_)));
    }

    /// `FieldElement` must encode as msgpack `bin` holding its big-endian bytes under
    /// every format, and must round-trip through the format-prefixed helpers.
    mod msgpack_repr_field_element {
        use super::super::{Format, deserialize_any_format, serialize_with_format};
        use acir_field::{AcirField, FieldElement};
        use rmpv::Value;

        /// Serializes `value` with `format` and drops the leading format byte,
        /// leaving only the msgpack payload.
        fn encoded_msgpack<F: Into<Format>>(value: &FieldElement, format: F) -> Vec<u8> {
            let mut bz = serialize_with_format(value, format.into()).unwrap();
            bz.remove(0);
            bz
        }

        /// Decodes one msgpack value from `bz`, asserts that it is `bin`, and returns its bytes.
        fn assert_bin(bz: &[u8]) -> Vec<u8> {
            let msg = rmpv::decode::read_value::<&[u8]>(&mut &bz[..]).unwrap();
            match msg {
                Value::Binary(v) => v,
                other => panic!("expected msgpack Binary for FieldElement; got {other:?}"),
            }
        }

        #[test]
        fn msgpack_tagged_emits_bin() {
            let v = FieldElement::from(1u128);
            let bytes = assert_bin(&encoded_msgpack(&v, Format::MsgpackTagged));
            assert_eq!(bytes.len(), v.to_be_bytes().len());
            assert_eq!(bytes, v.to_be_bytes());
        }

        #[test]
        fn msgpack_compact_emits_bin() {
            let v = FieldElement::from(42u128);
            let bytes = assert_bin(&encoded_msgpack(&v, Format::MsgpackCompact));
            assert_eq!(bytes, v.to_be_bytes());
        }

        #[test]
        fn msgpack_named_emits_bin() {
            let v = FieldElement::from(99u128);
            let bytes = assert_bin(&encoded_msgpack(&v, Format::Msgpack));
            assert_eq!(bytes, v.to_be_bytes());
        }

        #[test]
        fn round_trips_under_all_formats() {
            for &format in &[Format::Msgpack, Format::MsgpackCompact, Format::MsgpackTagged] {
                let v = FieldElement::from(7u128);
                let bytes = serialize_with_format(&v, format).unwrap();
                let decoded: FieldElement = deserialize_any_format(&bytes).unwrap();
                assert_eq!(decoded, v, "round-trip failed under {format:?}");
            }
        }
    }

    /// Every kebab-case format name parses, and unknown names are rejected.
    #[test]
    fn format_from_str() {
        assert_eq!(Format::from_str("msgpack").unwrap(), Format::Msgpack);
        assert_eq!(Format::from_str("msgpack-compact").unwrap(), Format::MsgpackCompact);
        assert_eq!(Format::from_str("msgpack-tagged").unwrap(), Format::MsgpackTagged);
        assert!(Format::from_str("json").is_err());
    }

    /// The discriminants are written as the first byte of serialized buffers, so they
    /// are pinned here together with the default format.
    #[test]
    fn format_discriminants_and_default_are_stable() {
        assert_eq!(u8::from(Format::Msgpack), 2);
        assert_eq!(u8::from(Format::MsgpackCompact), 3);
        assert_eq!(u8::from(Format::MsgpackTagged), 4);
        assert_eq!(Format::default(), Format::MsgpackCompact);
    }

    /// An empty buffer and an unknown format prefix are both rejected as invalid input.
    #[test]
    fn deserialize_any_format_rejects_empty_and_unknown_prefix() {
        use acir_field::FieldElement;

        let empty = super::deserialize_any_format::<FieldElement>(&[]).unwrap_err();
        assert_eq!(empty.kind(), std::io::ErrorKind::InvalidInput);

        let unknown = super::deserialize_any_format::<FieldElement>(&[0xff]).unwrap_err();
        assert_eq!(unknown.kind(), std::io::ErrorKind::InvalidInput);
    }

    /// A struct with `#[tag]` annotations round-trips through the tagged format,
    /// and the output carries the `MsgpackTagged` prefix byte.
    #[test]
    fn msgpack_tagged_format_roundtrip() {
        use msgpack_tagged::MsgpackTagged;

        #[derive(serde::Serialize, serde::Deserialize, MsgpackTagged, PartialEq, Eq, Debug)]
        struct Foo {
            #[tag(0)]
            a: u32,
            #[tag(1)]
            b: bool,
        }

        let value = Foo { a: 7, b: true };
        let bytes = super::serialize_with_format(&value, Format::MsgpackTagged).unwrap();
        assert_eq!(bytes[0], Format::MsgpackTagged as u8);
        let decoded: Foo = super::deserialize_any_format(&bytes).unwrap();
        assert_eq!(decoded, value);
    }

    /// Under the ACIR policy, `Program` and `Circuit` are maps keyed by integer tags,
    /// while a nested `Expression` falls back to the default array encoding.
    #[test]
    fn msgpack_tagged_acir_policy_program_tagged_expression_array() {
        use crate::circuit::{Circuit, Opcode, Program};
        use crate::native_types::Expression;
        use acir_field::FieldElement;
        use rmpv::Value;

        let expr: Expression<FieldElement> = Expression {
            mul_terms: vec![],
            linear_combinations: vec![],
            q_c: FieldElement::from(1u128),
        };
        let circuit: Circuit<FieldElement> = Circuit {
            function_name: "main".to_string(),
            opcodes: vec![Opcode::AssertZero(expr)],
            ..Circuit::default()
        };
        let program =
            Program::<FieldElement> { functions: vec![circuit], unconstrained_functions: vec![] };

        let bytes = super::msgpack_tagged_serialize_acir(&program).expect("encode succeeds");
        let value = rmpv::decode::read_value(&mut bytes.as_slice()).expect("valid msgpack");

        let Value::Map(program_entries) = &value else {
            panic!("expected fixmap for Program under the ACIR policy, got {value:?}");
        };
        assert_eq!(program_entries.len(), 2);
        assert!(program_entries.iter().all(|(k, _)| matches!(k, Value::Integer(_))));

        let functions = program_entries
            .iter()
            .find(|(k, _)| k.as_u64() == Some(0))
            .map(|(_, v)| v)
            .expect("functions tag present");
        let Value::Array(functions) = functions else {
            panic!("functions field should be a msgpack array, got {functions:?}");
        };
        let Value::Map(circuit_entries) = &functions[0] else {
            panic!("expected fixmap for Circuit, got {:?}", functions[0]);
        };
        assert!(circuit_entries.iter().all(|(k, _)| matches!(k, Value::Integer(_))));

        let opcodes_value = circuit_entries
            .iter()
            .find(|(k, _)| k.as_u64() == Some(1))
            .map(|(_, v)| v)
            .expect("opcodes tag present on Circuit wire");
        let Value::Array(opcodes) = opcodes_value else {
            panic!("opcodes should be a msgpack array, got {opcodes_value:?}");
        };
        let Value::Map(opcode) = &opcodes[0] else {
            panic!("opcode should be a 1-entry map, got {:?}", opcodes[0]);
        };
        assert_eq!(opcode.len(), 1);
        let expression_value = &opcode[0].1;
        assert!(
            matches!(expression_value, Value::Array(_)),
            "expected fixarray for nested Expression under default Array policy, got \
             {expression_value:?}",
        );
    }

    /// A default `Program` round-trips through the format-prefixed tagged encoding.
    #[test]
    fn msgpack_tagged_acir_policy_program_roundtrips_through_format() {
        use crate::circuit::Program;
        use acir_field::FieldElement;

        let program = Program::<FieldElement>::default();
        let bytes = super::serialize_with_format(&program, Format::MsgpackTagged).unwrap();
        let decoded: Program<FieldElement> = super::deserialize_any_format(&bytes).unwrap();
        assert_eq!(decoded, program);
    }
}