1pub mod black_box_functions;
4pub mod brillig;
5pub mod opcodes;
6
7use crate::{
8 SerializationFormat,
9 circuit::opcodes::display_opcode,
10 native_types::{Expression, Witness},
11 serialization::{self, deserialize_any_format, serialize_with_format},
12};
13use acir_field::AcirField;
14use msgpack_tagged::MsgpackTagged;
15pub use opcodes::Opcode;
16use thiserror::Error;
17
18use std::{collections::HashMap, io::prelude::*, num::ParseIntError, str::FromStr};
19
20use base64::Engine;
21use flate2::Compression;
22use serde::{Deserialize, Deserializer, Serialize, Serializer, de::Error as DeserializationError};
23
24use std::collections::BTreeSet;
25
26use self::{brillig::BrilligBytecode, opcodes::BlockId};
27
/// A program made of constrained ACIR circuits (`functions`) and unconstrained Brillig
/// bytecode (`unconstrained_functions`).
///
/// NOTE(review): the field order (used by the serde derive) and the `#[tag(..)]` numbers
/// (presumably consumed by `MsgpackTagged`) appear to define the serialized form; confirm
/// compatibility requirements before reordering or renumbering fields.
#[derive(Clone, PartialEq, Eq, Serialize, Deserialize, Default, Hash, MsgpackTagged)]
#[cfg_attr(feature = "arb", derive(proptest_derive::Arbitrary))]
#[tagged(allow_unknown_tags)]
pub struct Program<F: AcirField> {
    /// Constrained functions; rendered as `func <index>` by `display_program`.
    #[tag(0)]
    pub functions: Vec<Circuit<F>>,
    /// Unconstrained Brillig functions; rendered as `unconstrained func <index>: <name>`.
    #[tag(1)]
    pub unconstrained_functions: Vec<BrilligBytecode<F>>,
}
39
/// A single constrained (ACIR) function.
///
/// `Display`/`Debug` are implemented by hand via `display_circuit`.
/// NOTE(review): field order and `#[tag(..)]` numbers appear to be part of the serialized
/// form; confirm before reordering or renumbering.
#[derive(Clone, PartialEq, Eq, Default, Hash, Serialize, Deserialize, MsgpackTagged)]
#[cfg_attr(feature = "arb", derive(proptest_derive::Arbitrary))]
#[tagged(allow_unknown_tags)]
pub struct Circuit<F: AcirField> {
    /// Name of the function this circuit represents.
    #[tag(0)]
    pub function_name: String,
    /// The circuit's opcodes. An opcode's position in this vector is the index carried by
    /// `OpcodeLocation::Acir` (as used when annotating opcodes in `display_circuit`).
    #[tag(1)]
    pub opcodes: Vec<Opcode<F>>,
    /// Private input witnesses; part of `Circuit::circuit_arguments`.
    #[tag(2)]
    pub private_parameters: BTreeSet<Witness>,
    /// Public input witnesses; part of both `circuit_arguments` and `public_inputs`.
    #[tag(3)]
    pub public_parameters: PublicInputs,
    /// Witnesses returned by the circuit; included in `Circuit::public_inputs`.
    #[tag(4)]
    pub return_values: PublicInputs,
    /// Assertion payloads keyed by the location of the opcode they belong to. When error
    /// types are supplied, `display_circuit` uses these to annotate opcodes with messages.
    #[tag(5)]
    pub assert_messages: Vec<(OpcodeLocation, AssertionPayload<F>)>,
}
77
/// One item of an assertion payload: either an arithmetic expression or a reference to a
/// memory block.
#[derive(Debug, Clone, PartialEq, Eq, Hash)]
#[derive(Serialize, Deserialize, MsgpackTagged)]
#[cfg_attr(feature = "arb", derive(proptest_derive::Arbitrary))]
pub enum ExpressionOrMemory<F> {
    /// An inline expression value.
    #[tag(0)]
    Expression(Expression<F>),
    /// A memory block identified by its `BlockId`.
    #[tag(1)]
    Memory(BlockId),
}
88
/// Data attached to an assertion: which error type it raises and the values it carries.
#[derive(Debug, Clone, PartialEq, Eq, Hash)]
#[derive(Serialize, Deserialize, MsgpackTagged)]
#[cfg_attr(feature = "arb", derive(proptest_derive::Arbitrary))]
pub struct AssertionPayload<F> {
    /// Raw error selector; wrapped in `ErrorSelector::new` to look up the error type
    /// (see `display_circuit`).
    #[tag(0)]
    pub error_selector: u64,
    /// Values carried by the assertion, each an expression or a memory block.
    #[tag(1)]
    pub payload: Vec<ExpressionOrMemory<F>>,
}
108
/// Newtype over a raw `u64` that identifies an error type.
///
/// Used as the key of the error-type tables passed to `display_circuit` / `display_program`.
#[derive(Debug, Copy, PartialEq, Eq, Hash, Clone, PartialOrd, Ord)]
pub struct ErrorSelector(u64);

impl ErrorSelector {
    /// Wraps a raw selector value.
    pub fn new(integer: u64) -> Self {
        Self(integer)
    }

    /// Returns the raw selector value.
    pub fn as_u64(&self) -> u64 {
        let Self(raw) = *self;
        raw
    }
}
122
123impl Serialize for ErrorSelector {
124 fn serialize<S>(&self, serializer: S) -> Result<S::Ok, S::Error>
125 where
126 S: Serializer,
127 {
128 self.0.to_string().serialize(serializer)
129 }
130}
131
132impl<'de> Deserialize<'de> for ErrorSelector {
133 fn deserialize<D>(deserializer: D) -> Result<Self, D::Error>
134 where
135 D: Deserializer<'de>,
136 {
137 let s: String = Deserialize::deserialize(deserializer)?;
138 let as_u64 = s.parse().map_err(serde::de::Error::custom)?;
139 Ok(ErrorSelector(as_u64))
140 }
141}
142
/// Position of an opcode: either directly in a circuit's opcode list, or inside Brillig
/// code (presumably the Brillig call made by the ACIR opcode at `acir_index`).
///
/// The textual forms produced by `Display` and accepted by `FromStr` are `"<acir>"` and
/// `"<acir>.<brillig>"`.
#[derive(Debug, Copy, Clone, PartialEq, Eq, Hash, PartialOrd, Ord)]
#[derive(Serialize, Deserialize, MsgpackTagged)]
#[cfg_attr(feature = "arb", derive(proptest_derive::Arbitrary))]
pub enum OpcodeLocation {
    /// Index into a circuit's `opcodes`.
    #[tag(0)]
    Acir(usize),
    /// A position inside Brillig bytecode.
    #[tag(1)]
    Brillig {
        /// Index of the associated ACIR opcode.
        #[tag(0)]
        acir_index: usize,
        /// Index of the opcode within the Brillig bytecode.
        #[tag(1)]
        brillig_index: usize,
    },
}
161
162#[derive(Debug, Copy, Clone, PartialEq, Eq, Hash, PartialOrd, Ord, Serialize, Deserialize)]
165pub struct AcirOpcodeLocation(usize);
166impl std::fmt::Display for AcirOpcodeLocation {
167 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
168 write!(f, "{}", self.0)
169 }
170}
171
172impl AcirOpcodeLocation {
173 pub fn new(index: usize) -> Self {
174 AcirOpcodeLocation(index)
175 }
176 pub fn index(&self) -> usize {
177 self.0
178 }
179}
180#[derive(Debug, Copy, Clone, PartialEq, Eq, Hash, PartialOrd, Ord, Serialize, Deserialize)]
183pub struct BrilligOpcodeLocation(usize);
184
185impl BrilligOpcodeLocation {
186 pub fn new(index: usize) -> Self {
187 BrilligOpcodeLocation(index)
188 }
189 pub fn index(&self) -> usize {
190 self.0
191 }
192}
193
194impl OpcodeLocation {
195 pub fn to_brillig_location(self) -> Option<BrilligOpcodeLocation> {
199 match self {
200 OpcodeLocation::Brillig { brillig_index, .. } => {
201 Some(BrilligOpcodeLocation::new(brillig_index))
202 }
203 OpcodeLocation::Acir(_) => None,
204 }
205 }
206}
207
208impl std::fmt::Display for OpcodeLocation {
209 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
210 match self {
211 OpcodeLocation::Acir(index) => write!(f, "{index}"),
212 OpcodeLocation::Brillig { acir_index, brillig_index } => {
213 write!(f, "{acir_index}.{brillig_index}")
214 }
215 }
216 }
217}
218
/// Error returned by `OpcodeLocation::from_str`.
#[derive(Error, Debug)]
pub enum OpcodeLocationFromStrError {
    /// The input (carried verbatim) is not of the form `"<acir>"` or `"<acir>.<brillig>"`.
    #[error("Invalid opcode location string: {0}")]
    InvalidOpcodeLocationString(String),
}
224
225impl FromStr for OpcodeLocation {
228 type Err = OpcodeLocationFromStrError;
229 fn from_str(s: &str) -> Result<Self, Self::Err> {
230 let parts: Vec<_> = s.split('.').collect();
231
232 if parts.is_empty() || parts.len() > 2 {
233 return Err(OpcodeLocationFromStrError::InvalidOpcodeLocationString(s.to_string()));
234 }
235
236 fn parse_components(parts: &[&str]) -> Result<OpcodeLocation, ParseIntError> {
237 match parts.len() {
238 1 => {
239 let index = parts[0].parse()?;
240 Ok(OpcodeLocation::Acir(index))
241 }
242 2 => {
243 let acir_index = parts[0].parse()?;
244 let brillig_index = parts[1].parse()?;
245 Ok(OpcodeLocation::Brillig { acir_index, brillig_index })
246 }
247 _ => unreachable!("`OpcodeLocation` has too many components"),
248 }
249 }
250
251 parse_components(&parts)
252 .map_err(|_| OpcodeLocationFromStrError::InvalidOpcodeLocationString(s.to_string()))
253 }
254}
255
256impl std::fmt::Display for BrilligOpcodeLocation {
257 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
258 let index = self.0;
259 write!(f, "{index}")
260 }
261}
262
263impl<F: AcirField> Circuit<F> {
264 pub fn circuit_arguments(&self) -> BTreeSet<Witness> {
266 self.private_parameters.union(&self.public_parameters.0).copied().collect()
267 }
268
269 pub fn public_inputs(&self) -> PublicInputs {
272 let public_inputs =
273 self.public_parameters.0.union(&self.return_values.0).copied().collect();
274 PublicInputs(public_inputs)
275 }
276}
277
278impl<F: Serialize + AcirField + MsgpackTagged> Program<F> {
279 fn compress(buf: &[u8]) -> std::io::Result<Vec<u8>> {
281 let mut compressed: Vec<u8> = Vec::new();
282 let mut encoder = flate2::write::GzEncoder::new(&mut compressed, Compression::default());
284 encoder.write_all(buf)?;
285 encoder.finish()?;
286 Ok(compressed)
287 }
288
289 pub fn serialize_program_with_format(program: &Self, format: serialization::Format) -> Vec<u8> {
291 let program_bytes =
292 serialize_with_format(program, format).expect("expected circuit to be serializable");
293 Self::compress(&program_bytes).expect("expected circuit to compress")
294 }
295
296 pub fn serialize_program(program: &Self) -> Vec<u8> {
298 let format = SerializationFormat::from_env().expect("invalid format");
299 Self::serialize_program_with_format(program, format.unwrap_or_default())
300 }
301
302 pub fn serialize_program_base64<S>(program: &Self, s: S) -> Result<S::Ok, S::Error>
304 where
305 S: Serializer,
306 {
307 let program_bytes = Program::serialize_program(program);
308 let encoded_b64 = base64::engine::general_purpose::STANDARD.encode(program_bytes);
309 s.serialize_str(&encoded_b64)
310 }
311}
312
313impl<F: AcirField + for<'a> Deserialize<'a> + MsgpackTagged> Program<F> {
314 fn read<R: Read>(reader: R) -> std::io::Result<Self> {
316 let mut gz_decoder = flate2::read::GzDecoder::new(reader);
317 let mut buf = Vec::new();
318 gz_decoder.read_to_end(&mut buf)?;
319 let program = deserialize_any_format(&buf)?;
320 Ok(program)
321 }
322
323 pub fn deserialize_program(serialized_circuit: &[u8]) -> std::io::Result<Self> {
325 Program::read(serialized_circuit)
326 }
327
328 pub fn deserialize_program_base64<'de, D>(deserializer: D) -> Result<Self, D::Error>
330 where
331 D: Deserializer<'de>,
332 {
333 let bytecode_b64: String = Deserialize::deserialize(deserializer)?;
334 let program_bytes = base64::engine::general_purpose::STANDARD
335 .decode(bytecode_b64)
336 .map_err(D::Error::custom)?;
337 let circuit = Self::deserialize_program(&program_bytes).map_err(D::Error::custom)?;
338 Ok(circuit)
339 }
340}
341
342impl<F: AcirField> std::fmt::Display for Circuit<F> {
343 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
344 display_circuit(self, None, f)
345 }
346}
347
348pub fn display_circuit<F: AcirField>(
349 circuit: &Circuit<F>,
350 error_types: Option<&HashMap<ErrorSelector, String>>,
351 f: &mut std::fmt::Formatter<'_>,
352) -> std::fmt::Result {
353 let write_witness_indices =
354 |f: &mut std::fmt::Formatter<'_>, indices: &[u32]| -> Result<(), std::fmt::Error> {
355 write!(f, "[")?;
356 for (index, witness_index) in indices.iter().enumerate() {
357 write!(f, "w{witness_index}")?;
358 if index != indices.len() - 1 {
359 write!(f, ", ")?;
360 }
361 }
362 writeln!(f, "]")
363 };
364
365 write!(f, "private parameters: ")?;
366 write_witness_indices(
367 f,
368 &circuit
369 .private_parameters
370 .iter()
371 .map(|witness| witness.witness_index())
372 .collect::<Vec<_>>(),
373 )?;
374
375 write!(f, "public parameters: ")?;
376 write_witness_indices(f, &circuit.public_parameters.indices())?;
377
378 write!(f, "return values: ")?;
379 write_witness_indices(f, &circuit.return_values.indices())?;
380
381 let assert_messages_by_opcode_location =
382 circuit.assert_messages.iter().cloned().collect::<HashMap<_, _>>();
383
384 for (index, opcode) in circuit.opcodes.iter().enumerate() {
385 display_opcode(opcode, Some(&circuit.return_values), f)?;
386
387 if let Some(error_types) = error_types {
388 let location = OpcodeLocation::Acir(index);
389 if let Some(payload) = assert_messages_by_opcode_location.get(&location)
390 && let Some(message) = error_types.get(&ErrorSelector::new(payload.error_selector))
391 {
392 write!(f, " // {message}")?;
393 }
394 }
395 writeln!(f)?;
396 }
397 Ok(())
398}
399
400impl<F: AcirField> std::fmt::Debug for Circuit<F> {
401 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
402 std::fmt::Display::fmt(self, f)
403 }
404}
405
406impl<F: AcirField> std::fmt::Display for Program<F> {
407 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
408 display_program(self, None, f)
409 }
410}
411
412pub fn display_program<F: AcirField>(
413 program: &Program<F>,
414 error_types: Option<&HashMap<ErrorSelector, String>>,
415 f: &mut std::fmt::Formatter<'_>,
416) -> std::fmt::Result {
417 for (func_index, function) in program.functions.iter().enumerate() {
418 writeln!(f, "func {func_index}")?;
419 display_circuit(function, error_types, f)?;
420 writeln!(f)?;
421 }
422 for (func_index, function) in program.unconstrained_functions.iter().enumerate() {
423 writeln!(f, "unconstrained func {func_index}: {}", function.function_name)?;
424 let width = function.bytecode.len().to_string().len();
425 for (index, opcode) in function.bytecode.iter().enumerate() {
426 write!(f, "{index:>width$}: {opcode}")?;
427
428 if let ::brillig::Opcode::IndirectConst { value, .. } = opcode
429 && let Some(value) = value.try_to_u64()
430 && let Some(message) =
431 error_types.and_then(|error_types| error_types.get(&ErrorSelector::new(value)))
432 {
433 write!(f, " // {message:?}")?;
434 }
435
436 writeln!(f)?;
437 }
438 }
439 Ok(())
440}
441
442impl<F: AcirField> std::fmt::Debug for Program<F> {
443 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
444 std::fmt::Display::fmt(self, f)
445 }
446}
447
448#[derive(Clone, Debug, PartialEq, Eq, Serialize, Deserialize, Default, Hash, MsgpackTagged)]
449#[cfg_attr(feature = "arb", derive(proptest_derive::Arbitrary))]
450pub struct PublicInputs(pub BTreeSet<Witness>);
451
452impl PublicInputs {
453 pub fn indices(&self) -> Vec<u32> {
455 self.0.iter().map(|witness| witness.witness_index()).collect()
456 }
457
458 pub fn contains(&self, index: usize) -> bool {
459 self.0.contains(&Witness::new(index as u32))
460 }
461}
462
#[cfg(test)]
mod tests {
    use super::{Circuit, Compression};
    use crate::circuit::Program;
    use acir_field::{AcirField, FieldElement};
    use msgpack_tagged::MsgpackTagged;
    use serde::{Deserialize, Serialize};

    /// A parsed program survives `serialize_program` followed by `deserialize_program`.
    #[test]
    fn serialization_roundtrip() {
        let src = "
        private parameters: []
        public parameters: [w2, w12]
        return values: [w4, w12]
        BLACKBOX::AND lhs: w1, rhs: w2, output: w3, bits: 4
        BLACKBOX::RANGE input: w1, bits: 8
        ";
        let circuit = Circuit::from_str(src).unwrap();
        let program = Program { functions: vec![circuit], unconstrained_functions: Vec::new() };

        fn read_write<F: Serialize + for<'a> Deserialize<'a> + AcirField + MsgpackTagged>(
            program: Program<F>,
        ) -> (Program<F>, Program<F>) {
            let bytes = Program::serialize_program(&program);
            let got_program = Program::deserialize_program(&bytes).unwrap();
            (program, got_program)
        }

        let (circ, got_circ) = read_write(program);
        assert_eq!(circ, got_circ);
    }

    /// A parsed program survives a serde JSON round trip.
    #[test]
    fn test_serialize() {
        let src = "
        private parameters: []
        public parameters: [w2]
        return values: [w2]
        ASSERT 0 = 8
        BLACKBOX::RANGE input: w1, bits: 8
        BLACKBOX::AND lhs: w1, rhs: w2, output: w3, bits: 4
        BLACKBOX::KECCAKF1600 inputs: [w1, w2, w3, w4, w5, w6, w7, w8, w9, w10, w11, w12, w13, w14, w15, w16, w17, w18, w19, w20, w21, w22, w23, w24, w25], outputs: [w26, w27, w28, w29, w30, w31, w32, w33, w34, w35, w36, w37, w38, w39, w40, w41, w42, w43, w44, w45, w46, w47, w48, w49, w50]
        ";
        let circuit = Circuit::from_str(src).unwrap();
        let program = Program { functions: vec![circuit], unconstrained_functions: Vec::new() };

        let json = serde_json::to_string_pretty(&program).unwrap();

        let deserialized = serde_json::from_str(&json).unwrap();
        assert_eq!(program, deserialized);
    }

    /// Valid gzip wrapping bytes that are not a program yields an error, not a panic.
    #[test]
    fn does_not_panic_on_invalid_circuit() {
        use std::io::Write;

        let bad_circuit = "I'm not an ACIR circuit".as_bytes();

        let mut zipped_bad_circuit = Vec::new();
        let mut encoder =
            flate2::write::GzEncoder::new(&mut zipped_bad_circuit, Compression::default());
        encoder.write_all(bad_circuit).unwrap();
        encoder.finish().unwrap();

        let deserialization_result: Result<Program<FieldElement>, _> =
            Program::deserialize_program(&zipped_bad_circuit);
        assert!(deserialization_result.is_err());
    }

    /// The `Display` output of a parsed circuit matches its source text.
    #[test]
    fn circuit_display_snapshot() {
        let src = "
        private parameters: []
        public parameters: [w2]
        return values: [w2]
        ASSERT 0 = 2*w1 + 8
        BLACKBOX::RANGE input: w1, bits: 8
        BLACKBOX::AND lhs: w1, rhs: w2, output: w3, bits: 4
        BLACKBOX::KECCAKF1600 inputs: [w1, w2, w3, w4, w5, w6, w7, w8, w9, w10, w11, w12, w13, w14, w15, w16, w17, w18, w19, w20, w21, w22, w23, w24, w25], outputs: [w26, w27, w28, w29, w30, w31, w32, w33, w34, w35, w36, w37, w38, w39, w40, w41, w42, w43, w44, w45, w46, w47, w48, w49, w50]
        ";
        let circuit = Circuit::from_str(src).unwrap();

        insta::assert_snapshot!(
            circuit.to_string(),
            @r"
        private parameters: []
        public parameters: [w2]
        return values: [w2]
        ASSERT 0 = 2*w1 + 8
        BLACKBOX::RANGE input: w1, bits: 8
        BLACKBOX::AND lhs: w1, rhs: w2, output: w3, bits: 4
        BLACKBOX::KECCAKF1600 inputs: [w1, w2, w3, w4, w5, w6, w7, w8, w9, w10, w11, w12, w13, w14, w15, w16, w17, w18, w19, w20, w21, w22, w23, w24, w25], outputs: [w26, w27, w28, w29, w30, w31, w32, w33, w34, w35, w36, w37, w38, w39, w40, w41, w42, w43, w44, w45, w46, w47, w48, w49, w50]
        "
        );
    }

    /// Compile-time check that `Program` implements `MsgpackTagged`.
    #[test]
    fn ensure_program_is_msgpack_tagged() {
        fn assert_impl<T: MsgpackTagged>() {}
        assert_impl::<Program<FieldElement>>();
    }

    mod props {
        use acir_field::FieldElement;
        use msgpack_tagged::MsgpackTagged;
        use proptest::prelude::*;
        use proptest::test_runner::{TestCaseResult, TestRunner};

        use crate::circuit::Program;
        use crate::native_types::{WitnessMap, WitnessStack};
        use crate::serialization::*;

        /// Collection-size bound installed when `SIZE_RANGE_KEY` is not already set, to keep
        /// generated values small.
        const MAX_SIZE_RANGE: usize = 5;
        /// Environment variable holding the default collection-size bound (presumably read
        /// by proptest's default configuration).
        const SIZE_RANGE_KEY: &str = "PROPTEST_MAX_DEFAULT_SIZE_RANGE";

        // Local wrapper around `FieldElement` so this module can implement `Arbitrary`
        // and `MsgpackTagged` for it.
        acir_field::field_wrapper!(TestField, FieldElement);

        impl Arbitrary for TestField {
            type Parameters = ();
            type Strategy = BoxedStrategy<Self>;

            /// Generates field elements from arbitrary `u128` values.
            fn arbitrary_with(_args: Self::Parameters) -> Self::Strategy {
                any::<u128>().prop_map(|v| Self(FieldElement::from(v))).boxed()
            }
        }

        impl MsgpackTagged for TestField {
            const TAGGED: msgpack_tagged::Tagged = msgpack_tagged::Tagged::empty_product();
            fn register_into(_reg: &mut msgpack_tagged::TagRegistry) {}
        }

        /// Runs `f` on `cases` generated values of `T`, with `SIZE_RANGE_KEY` set to
        /// `MAX_SIZE_RANGE` unless the caller's environment already set it.
        ///
        /// The environment is restored afterwards: the original value is put back if there
        /// was one; otherwise the variable is removed. Setting it to `""` instead would leave
        /// an unparsable value behind. Later runs would also then see the variable as set
        /// and skip installing the small default.
        #[allow(unsafe_code)] // `std::env::set_var`/`remove_var` are `unsafe` functions.
        fn run_with_max_size_range<T, F>(cases: u32, f: F)
        where
            T: Arbitrary,
            F: Fn(T) -> TestCaseResult,
        {
            let orig_size_range = std::env::var(SIZE_RANGE_KEY).ok();
            if orig_size_range.is_none() {
                // SAFETY: mutating the process environment can race with concurrent reads
                // from other threads. NOTE(review): tests may run in parallel; this is
                // accepted here as test-only code.
                unsafe {
                    std::env::set_var(SIZE_RANGE_KEY, MAX_SIZE_RANGE.to_string());
                }
            }

            let mut runner = TestRunner::new(ProptestConfig { cases, ..Default::default() });
            let result = runner.run(&any::<T>(), f);

            // SAFETY: same considerations as the `set_var` call above.
            unsafe {
                match orig_size_range {
                    Some(value) => std::env::set_var(SIZE_RANGE_KEY, value),
                    None => std::env::remove_var(SIZE_RANGE_KEY),
                }
            }

            result.unwrap();
        }

        /// Arbitrary programs round-trip through every serialization format.
        #[test]
        fn prop_program_arb_roundtrip() {
            run_with_max_size_range(100, |(program, format): (Program<TestField>, Format)| {
                let bz = serialize_with_format(&program, format)?;
                let de = deserialize_any_format(&bz)?;
                prop_assert_eq!(program, de);
                Ok(())
            });
        }

        /// Arbitrary programs round-trip through the compressed program encoding.
        #[test]
        fn prop_program_roundtrip() {
            run_with_max_size_range(10, |program: Program<TestField>| {
                let bz = Program::serialize_program(&program);
                let de = Program::deserialize_program(&bz)?;
                prop_assert_eq!(program, de);
                Ok(())
            });
        }

        #[test]
        fn prop_witness_stack_arb_roundtrip() {
            run_with_max_size_range(10, |(witness, format): (WitnessStack<TestField>, Format)| {
                let bz = serialize_with_format(&witness, format)?;
                let de = deserialize_any_format(&bz)?;
                prop_assert_eq!(witness, de);
                Ok(())
            });
        }

        #[test]
        fn prop_witness_stack_roundtrip() {
            run_with_max_size_range(10, |witness: WitnessStack<TestField>| {
                let bz = witness.serialize()?;
                let de = WitnessStack::deserialize(bz.as_slice())?;
                prop_assert_eq!(witness, de);
                Ok(())
            });
        }

        #[test]
        fn prop_witness_map_arb_roundtrip() {
            run_with_max_size_range(10, |(witness, format): (WitnessMap<TestField>, Format)| {
                let bz = serialize_with_format(&witness, format)?;
                let de = deserialize_any_format(&bz)?;
                prop_assert_eq!(witness, de);
                Ok(())
            });
        }

        #[test]
        fn prop_witness_map_roundtrip() {
            run_with_max_size_range(10, |witness: WitnessMap<TestField>| {
                let bz = witness.serialize()?;
                let de = WitnessMap::deserialize(bz.as_slice())?;
                prop_assert_eq!(witness, de);
                Ok(())
            });
        }
    }
}