1pub mod black_box_functions;
4pub mod brillig;
5pub mod opcodes;
6
7use crate::{
8 SerializationFormat,
9 circuit::opcodes::display_opcode,
10 native_types::{Expression, Witness},
11 serialization::{self, deserialize_any_format, serialize_with_format},
12};
13use acir_field::AcirField;
14pub use opcodes::Opcode;
15use thiserror::Error;
16
17use std::{collections::HashMap, io::prelude::*, num::ParseIntError, str::FromStr};
18
19use base64::Engine;
20use flate2::Compression;
21use serde::{Deserialize, Deserializer, Serialize, Serializer, de::Error as DeserializationError};
22
23use std::collections::BTreeSet;
24
25use self::{brillig::BrilligBytecode, opcodes::BlockId};
26
/// A program made of ACIR circuits plus the Brillig bytecode of its unconstrained functions.
///
/// `display_program` lists circuits as `func <i>` and Brillig functions as
/// `unconstrained func <i>`, where `<i>` is the element's position in its vector.
#[derive(Clone, PartialEq, Eq, Serialize, Deserialize, Default, Hash)]
#[cfg_attr(feature = "arb", derive(proptest_derive::Arbitrary))]
pub struct Program<F: AcirField> {
    /// Constrained functions, one ACIR [`Circuit`] each.
    pub functions: Vec<Circuit<F>>,
    /// Unconstrained functions, one Brillig bytecode listing each.
    pub unconstrained_functions: Vec<BrilligBytecode<F>>,
}
35
/// A single ACIR circuit: its opcodes plus the witnesses that make up its interface.
///
/// `Circuit` does not derive serde traits; the manual `Serialize` / `Deserialize` impls in this
/// module go through the private `CircuitWire` representation instead.
#[derive(Clone, PartialEq, Eq, Default, Hash)]
#[cfg_attr(feature = "arb", derive(proptest_derive::Arbitrary))]
pub struct Circuit<F: AcirField> {
    /// Name of the function this circuit belongs to. It deserializes to an empty string when
    /// absent (`#[serde(default)]` on the wire struct).
    pub function_name: String,
    /// The circuit's opcodes. `OpcodeLocation::Acir(i)` refers to `opcodes[i]`.
    pub opcodes: Vec<Opcode<F>>,
    /// Private parameters. Together with `public_parameters` these are the circuit's arguments.
    pub private_parameters: BTreeSet<Witness>,
    /// Public parameters.
    pub public_parameters: PublicInputs,
    /// Returned witnesses. `public_inputs()` also counts these as public.
    pub return_values: PublicInputs,
    /// Assertion payloads keyed by the location of the opcode they annotate.
    pub assert_messages: Vec<(OpcodeLocation, AssertionPayload<F>)>,
}
66
67#[derive(Serialize, Deserialize)]
71#[serde(rename = "Circuit")]
72struct CircuitWire<F: AcirField> {
73 #[serde(default)]
74 function_name: String,
75 current_witness_index: u32,
76 opcodes: Vec<Opcode<F>>,
77 private_parameters: BTreeSet<Witness>,
78 public_parameters: PublicInputs,
79 return_values: PublicInputs,
80 assert_messages: Vec<(OpcodeLocation, AssertionPayload<F>)>,
81}
82
83impl<F: AcirField + Serialize> Serialize for Circuit<F> {
84 fn serialize<S: Serializer>(&self, serializer: S) -> Result<S::Ok, S::Error> {
85 CircuitWire {
86 function_name: self.function_name.clone(),
87 current_witness_index: 0,
88 opcodes: self.opcodes.clone(),
89 private_parameters: self.private_parameters.clone(),
90 public_parameters: self.public_parameters.clone(),
91 return_values: self.return_values.clone(),
92 assert_messages: self.assert_messages.clone(),
93 }
94 .serialize(serializer)
95 }
96}
97
98impl<'de, F: AcirField + Deserialize<'de>> Deserialize<'de> for Circuit<F> {
99 fn deserialize<D: Deserializer<'de>>(deserializer: D) -> Result<Self, D::Error> {
100 let wire = CircuitWire::<F>::deserialize(deserializer)?;
101 Ok(Circuit {
102 function_name: wire.function_name,
103 opcodes: wire.opcodes,
104 private_parameters: wire.private_parameters,
105 public_parameters: wire.public_parameters,
106 return_values: wire.return_values,
107 assert_messages: wire.assert_messages,
108 })
109 }
110}
111
/// One value carried by an [`AssertionPayload`]: either an [`Expression`] or a memory block
/// identified by its [`BlockId`].
#[derive(Debug, Clone, PartialEq, Eq, Serialize, Deserialize, Hash)]
#[cfg_attr(feature = "arb", derive(proptest_derive::Arbitrary))]
pub enum ExpressionOrMemory<F> {
    /// An expression value.
    Expression(Expression<F>),
    /// A memory block, referenced by id.
    Memory(BlockId),
}
119
/// Data attached to an assertion: an error selector plus the values that make up the payload.
#[derive(Debug, Clone, PartialEq, Eq, Serialize, Deserialize, Hash)]
#[cfg_attr(feature = "arb", derive(proptest_derive::Arbitrary))]
pub struct AssertionPayload<F> {
    /// Raw error selector. `display_circuit` wraps it in [`ErrorSelector`] to look up a
    /// human-readable message.
    pub error_selector: u64,
    /// Payload values, each either an expression or a memory block.
    pub payload: Vec<ExpressionOrMemory<F>>,
}
136
137#[derive(Debug, Copy, PartialEq, Eq, Hash, Clone, PartialOrd, Ord)]
139pub struct ErrorSelector(u64);
140
141impl ErrorSelector {
142 pub fn new(integer: u64) -> Self {
143 ErrorSelector(integer)
144 }
145
146 pub fn as_u64(&self) -> u64 {
147 self.0
148 }
149}
150
151impl Serialize for ErrorSelector {
152 fn serialize<S>(&self, serializer: S) -> Result<S::Ok, S::Error>
153 where
154 S: Serializer,
155 {
156 self.0.to_string().serialize(serializer)
157 }
158}
159
160impl<'de> Deserialize<'de> for ErrorSelector {
161 fn deserialize<D>(deserializer: D) -> Result<Self, D::Error>
162 where
163 D: Deserializer<'de>,
164 {
165 let s: String = Deserialize::deserialize(deserializer)?;
166 let as_u64 = s.parse().map_err(serde::de::Error::custom)?;
167 Ok(ErrorSelector(as_u64))
168 }
169}
170
/// Location of an opcode within a circuit.
///
/// Displayed as `"<acir>"` or `"<acir>.<brillig>"`, and parsed back from the same two forms.
#[derive(Debug, Copy, Clone, PartialEq, Eq, Hash, PartialOrd, Ord, Serialize, Deserialize)]
#[cfg_attr(feature = "arb", derive(proptest_derive::Arbitrary))]
pub enum OpcodeLocation {
    /// Index of an ACIR opcode in [`Circuit::opcodes`].
    Acir(usize),
    /// A Brillig opcode, identified by an ACIR opcode index and a Brillig opcode index.
    /// NOTE(review): presumably `brillig_index` points into the Brillig code run by the ACIR
    /// opcode at `acir_index`; confirm.
    Brillig { acir_index: usize, brillig_index: usize },
}
181
/// Index of an ACIR opcode, as a standalone newtype (displayed as the bare index).
#[derive(Debug, Copy, Clone, PartialEq, Eq, Hash, PartialOrd, Ord, Serialize, Deserialize)]
pub struct AcirOpcodeLocation(usize);
186impl std::fmt::Display for AcirOpcodeLocation {
187 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
188 write!(f, "{}", self.0)
189 }
190}
191
192impl AcirOpcodeLocation {
193 pub fn new(index: usize) -> Self {
194 AcirOpcodeLocation(index)
195 }
196 pub fn index(&self) -> usize {
197 self.0
198 }
199}
/// Index of a Brillig opcode, e.g. as extracted by `OpcodeLocation::to_brillig_location`.
#[derive(Debug, Copy, Clone, PartialEq, Eq, Hash, PartialOrd, Ord, Serialize, Deserialize)]
pub struct BrilligOpcodeLocation(pub usize);
204
205impl OpcodeLocation {
206 pub fn to_brillig_location(self) -> Option<BrilligOpcodeLocation> {
210 match self {
211 OpcodeLocation::Brillig { brillig_index, .. } => {
212 Some(BrilligOpcodeLocation(brillig_index))
213 }
214 OpcodeLocation::Acir(_) => None,
215 }
216 }
217}
218
219impl std::fmt::Display for OpcodeLocation {
220 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
221 match self {
222 OpcodeLocation::Acir(index) => write!(f, "{index}"),
223 OpcodeLocation::Brillig { acir_index, brillig_index } => {
224 write!(f, "{acir_index}.{brillig_index}")
225 }
226 }
227 }
228}
229
/// Error returned when a string cannot be parsed as an [`OpcodeLocation`].
#[derive(Error, Debug)]
pub enum OpcodeLocationFromStrError {
    /// The input was not of the form `<usize>` or `<usize>.<usize>`; holds the offending input.
    #[error("Invalid opcode location string: {0}")]
    InvalidOpcodeLocationString(String),
}
235
236impl FromStr for OpcodeLocation {
239 type Err = OpcodeLocationFromStrError;
240 fn from_str(s: &str) -> Result<Self, Self::Err> {
241 let parts: Vec<_> = s.split('.').collect();
242
243 if parts.is_empty() || parts.len() > 2 {
244 return Err(OpcodeLocationFromStrError::InvalidOpcodeLocationString(s.to_string()));
245 }
246
247 fn parse_components(parts: Vec<&str>) -> Result<OpcodeLocation, ParseIntError> {
248 match parts.len() {
249 1 => {
250 let index = parts[0].parse()?;
251 Ok(OpcodeLocation::Acir(index))
252 }
253 2 => {
254 let acir_index = parts[0].parse()?;
255 let brillig_index = parts[1].parse()?;
256 Ok(OpcodeLocation::Brillig { acir_index, brillig_index })
257 }
258 _ => unreachable!("`OpcodeLocation` has too many components"),
259 }
260 }
261
262 parse_components(parts)
263 .map_err(|_| OpcodeLocationFromStrError::InvalidOpcodeLocationString(s.to_string()))
264 }
265}
266
267impl std::fmt::Display for BrilligOpcodeLocation {
268 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
269 let index = self.0;
270 write!(f, "{index}")
271 }
272}
273
274impl<F: AcirField> Circuit<F> {
275 pub fn circuit_arguments(&self) -> BTreeSet<Witness> {
277 self.private_parameters.union(&self.public_parameters.0).copied().collect()
278 }
279
280 pub fn public_inputs(&self) -> PublicInputs {
283 let public_inputs =
284 self.public_parameters.0.union(&self.return_values.0).copied().collect();
285 PublicInputs(public_inputs)
286 }
287}
288
289impl<F: Serialize + AcirField> Program<F> {
290 fn compress(buf: Vec<u8>) -> std::io::Result<Vec<u8>> {
292 let mut compressed: Vec<u8> = Vec::new();
293 let mut encoder = flate2::write::GzEncoder::new(&mut compressed, Compression::default());
295 encoder.write_all(&buf)?;
296 encoder.finish()?;
297 Ok(compressed)
298 }
299
300 pub fn serialize_program_with_format(program: &Self, format: serialization::Format) -> Vec<u8> {
302 let program_bytes =
303 serialize_with_format(program, format).expect("expected circuit to be serializable");
304 Self::compress(program_bytes).expect("expected circuit to compress")
305 }
306
307 pub fn serialize_program(program: &Self) -> Vec<u8> {
309 let format = SerializationFormat::from_env().expect("invalid format");
310 Self::serialize_program_with_format(program, format.unwrap_or_default())
311 }
312
313 pub fn serialize_program_base64<S>(program: &Self, s: S) -> Result<S::Ok, S::Error>
315 where
316 S: Serializer,
317 {
318 let program_bytes = Program::serialize_program(program);
319 let encoded_b64 = base64::engine::general_purpose::STANDARD.encode(program_bytes);
320 s.serialize_str(&encoded_b64)
321 }
322}
323
324impl<F: AcirField + for<'a> Deserialize<'a>> Program<F> {
325 fn read<R: Read>(reader: R) -> std::io::Result<Self> {
327 let mut gz_decoder = flate2::read::GzDecoder::new(reader);
328 let mut buf = Vec::new();
329 gz_decoder.read_to_end(&mut buf)?;
330 let program = deserialize_any_format(&buf)?;
331 Ok(program)
332 }
333
334 pub fn deserialize_program(serialized_circuit: &[u8]) -> std::io::Result<Self> {
336 Program::read(serialized_circuit)
337 }
338
339 pub fn deserialize_program_base64<'de, D>(deserializer: D) -> Result<Self, D::Error>
341 where
342 D: Deserializer<'de>,
343 {
344 let bytecode_b64: String = Deserialize::deserialize(deserializer)?;
345 let program_bytes = base64::engine::general_purpose::STANDARD
346 .decode(bytecode_b64)
347 .map_err(D::Error::custom)?;
348 let circuit = Self::deserialize_program(&program_bytes).map_err(D::Error::custom)?;
349 Ok(circuit)
350 }
351}
352
353impl<F: AcirField> std::fmt::Display for Circuit<F> {
354 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
355 display_circuit(self, None, f)
356 }
357}
358
359pub fn display_circuit<F: AcirField>(
360 circuit: &Circuit<F>,
361 error_types: Option<&HashMap<ErrorSelector, String>>,
362 f: &mut std::fmt::Formatter<'_>,
363) -> std::fmt::Result {
364 let write_witness_indices =
365 |f: &mut std::fmt::Formatter<'_>, indices: &[u32]| -> Result<(), std::fmt::Error> {
366 write!(f, "[")?;
367 for (index, witness_index) in indices.iter().enumerate() {
368 write!(f, "w{witness_index}")?;
369 if index != indices.len() - 1 {
370 write!(f, ", ")?;
371 }
372 }
373 writeln!(f, "]")
374 };
375
376 write!(f, "private parameters: ")?;
377 write_witness_indices(
378 f,
379 &circuit
380 .private_parameters
381 .iter()
382 .map(|witness| witness.witness_index())
383 .collect::<Vec<_>>(),
384 )?;
385
386 write!(f, "public parameters: ")?;
387 write_witness_indices(f, &circuit.public_parameters.indices())?;
388
389 write!(f, "return values: ")?;
390 write_witness_indices(f, &circuit.return_values.indices())?;
391
392 let assert_messages_by_opcode_location =
393 circuit.assert_messages.iter().cloned().collect::<HashMap<_, _>>();
394
395 for (index, opcode) in circuit.opcodes.iter().enumerate() {
396 display_opcode(opcode, Some(&circuit.return_values), f)?;
397
398 if let Some(error_types) = error_types {
399 let location = OpcodeLocation::Acir(index);
400 if let Some(payload) = assert_messages_by_opcode_location.get(&location)
401 && let Some(message) = error_types.get(&ErrorSelector::new(payload.error_selector))
402 {
403 write!(f, " // {message}")?;
404 }
405 }
406 writeln!(f)?;
407 }
408 Ok(())
409}
410
411impl<F: AcirField> std::fmt::Debug for Circuit<F> {
412 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
413 std::fmt::Display::fmt(self, f)
414 }
415}
416
417impl<F: AcirField> std::fmt::Display for Program<F> {
418 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
419 display_program(self, None, f)
420 }
421}
422
423pub fn display_program<F: AcirField>(
424 program: &Program<F>,
425 error_types: Option<&HashMap<ErrorSelector, String>>,
426 f: &mut std::fmt::Formatter<'_>,
427) -> std::fmt::Result {
428 for (func_index, function) in program.functions.iter().enumerate() {
429 writeln!(f, "func {func_index}")?;
430 display_circuit(function, error_types, f)?;
431 writeln!(f)?;
432 }
433 for (func_index, function) in program.unconstrained_functions.iter().enumerate() {
434 writeln!(f, "unconstrained func {func_index}: {}", function.function_name)?;
435 let width = function.bytecode.len().to_string().len();
436 for (index, opcode) in function.bytecode.iter().enumerate() {
437 write!(f, "{index:>width$}: {opcode}")?;
438
439 if let ::brillig::Opcode::IndirectConst { value, .. } = opcode
440 && let Some(value) = value.try_to_u64()
441 && let Some(message) =
442 error_types.and_then(|error_types| error_types.get(&ErrorSelector::new(value)))
443 {
444 write!(f, " // {message:?}")?;
445 }
446
447 writeln!(f)?;
448 }
449 }
450 Ok(())
451}
452
453impl<F: AcirField> std::fmt::Debug for Program<F> {
454 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
455 std::fmt::Display::fmt(self, f)
456 }
457}
458
/// An ordered set of witnesses. It is used for a circuit's public parameters and for its return
/// values.
#[derive(Clone, Debug, PartialEq, Eq, Serialize, Deserialize, Default, Hash)]
#[cfg_attr(feature = "arb", derive(proptest_derive::Arbitrary))]
pub struct PublicInputs(pub BTreeSet<Witness>);
462
463impl PublicInputs {
464 pub fn indices(&self) -> Vec<u32> {
466 self.0.iter().map(|witness| witness.witness_index()).collect()
467 }
468
469 pub fn contains(&self, index: usize) -> bool {
470 self.0.contains(&Witness(index as u32))
471 }
472}
473
#[cfg(test)]
mod tests {
    use super::{Circuit, Compression};
    use crate::circuit::Program;
    use acir_field::{AcirField, FieldElement};
    use serde::{Deserialize, Serialize};

    #[test]
    fn serialization_roundtrip() {
        let src = "
        private parameters: []
        public parameters: [w2, w12]
        return values: [w4, w12]
        BLACKBOX::AND lhs: w1, rhs: w2, output: w3, bits: 4
        BLACKBOX::RANGE input: w1, bits: 8
        ";
        let circuit = Circuit::from_str(src).unwrap();
        let program = Program { functions: vec![circuit], unconstrained_functions: Vec::new() };

        fn read_write<F: Serialize + for<'a> Deserialize<'a> + AcirField>(
            program: Program<F>,
        ) -> (Program<F>, Program<F>) {
            let bytes = Program::serialize_program(&program);
            let got_program = Program::deserialize_program(&bytes).unwrap();
            (program, got_program)
        }

        let (circ, got_circ) = read_write(program);
        assert_eq!(circ, got_circ);
    }

    #[test]
    fn test_serialize() {
        let src = "
        private parameters: []
        public parameters: [w2]
        return values: [w2]
        ASSERT 0 = 8
        BLACKBOX::RANGE input: w1, bits: 8
        BLACKBOX::AND lhs: w1, rhs: w2, output: w3, bits: 4
        BLACKBOX::KECCAKF1600 inputs: [w1, w2, w3, w4, w5, w6, w7, w8, w9, w10, w11, w12, w13, w14, w15, w16, w17, w18, w19, w20, w21, w22, w23, w24, w25], outputs: [w26, w27, w28, w29, w30, w31, w32, w33, w34, w35, w36, w37, w38, w39, w40, w41, w42, w43, w44, w45, w46, w47, w48, w49, w50]
        ";
        let circuit = Circuit::from_str(src).unwrap();
        let program = Program { functions: vec![circuit], unconstrained_functions: Vec::new() };

        let json = serde_json::to_string_pretty(&program).unwrap();

        let deserialized = serde_json::from_str(&json).unwrap();
        assert_eq!(program, deserialized);
    }

    #[test]
    fn does_not_panic_on_invalid_circuit() {
        use std::io::Write;

        let bad_circuit = "I'm not an ACIR circuit".as_bytes();

        // Gzip the garbage so decompression succeeds and the failure happens while decoding.
        let mut zipped_bad_circuit = Vec::new();
        let mut encoder =
            flate2::write::GzEncoder::new(&mut zipped_bad_circuit, Compression::default());
        encoder.write_all(bad_circuit).unwrap();
        encoder.finish().unwrap();

        let deserialization_result: Result<Program<FieldElement>, _> =
            Program::deserialize_program(&zipped_bad_circuit);
        assert!(deserialization_result.is_err());
    }

    #[test]
    fn circuit_display_snapshot() {
        let src = "
        private parameters: []
        public parameters: [w2]
        return values: [w2]
        ASSERT 0 = 2*w1 + 8
        BLACKBOX::RANGE input: w1, bits: 8
        BLACKBOX::AND lhs: w1, rhs: w2, output: w3, bits: 4
        BLACKBOX::KECCAKF1600 inputs: [w1, w2, w3, w4, w5, w6, w7, w8, w9, w10, w11, w12, w13, w14, w15, w16, w17, w18, w19, w20, w21, w22, w23, w24, w25], outputs: [w26, w27, w28, w29, w30, w31, w32, w33, w34, w35, w36, w37, w38, w39, w40, w41, w42, w43, w44, w45, w46, w47, w48, w49, w50]
        ";
        let circuit = Circuit::from_str(src).unwrap();

        insta::assert_snapshot!(
            circuit.to_string(),
            @r"
        private parameters: []
        public parameters: [w2]
        return values: [w2]
        ASSERT 0 = 2*w1 + 8
        BLACKBOX::RANGE input: w1, bits: 8
        BLACKBOX::AND lhs: w1, rhs: w2, output: w3, bits: 4
        BLACKBOX::KECCAKF1600 inputs: [w1, w2, w3, w4, w5, w6, w7, w8, w9, w10, w11, w12, w13, w14, w15, w16, w17, w18, w19, w20, w21, w22, w23, w24, w25], outputs: [w26, w27, w28, w29, w30, w31, w32, w33, w34, w35, w36, w37, w38, w39, w40, w41, w42, w43, w44, w45, w46, w47, w48, w49, w50]
        "
        );
    }

    mod props {
        use acir_field::FieldElement;
        use proptest::prelude::*;
        use proptest::test_runner::{TestCaseResult, TestRunner};

        use crate::circuit::Program;
        use crate::native_types::{WitnessMap, WitnessStack};
        use crate::serialization::*;

        /// Collection-size cap used for generated values when `SIZE_RANGE_KEY` is not already
        /// set in the environment.
        const MAX_SIZE_RANGE: usize = 5;
        /// Environment variable that proptest reads for its default collection size range.
        const SIZE_RANGE_KEY: &str = "PROPTEST_MAX_DEFAULT_SIZE_RANGE";

        acir_field::field_wrapper!(TestField, FieldElement);

        impl Arbitrary for TestField {
            type Parameters = ();
            type Strategy = BoxedStrategy<Self>;

            fn arbitrary_with(_args: Self::Parameters) -> Self::Strategy {
                any::<u128>().prop_map(|v| Self(FieldElement::from(v))).boxed()
            }
        }

        /// Runs `f` against `cases` generated values of `T`.
        ///
        /// If `SIZE_RANGE_KEY` is unset, it is set to `MAX_SIZE_RANGE` for the run to keep
        /// generated collections small. Afterwards the environment is restored exactly: the
        /// variable is removed again if it was unset, or reset to its previous value. Setting it
        /// to an empty string instead would make later calls think a (unparsable) value was
        /// configured and skip applying the cap.
        ///
        /// # Panics
        /// Panics if any generated case fails.
        #[allow(unsafe_code)]
        fn run_with_max_size_range<T, F>(cases: u32, f: F)
        where
            T: Arbitrary,
            F: Fn(T) -> TestCaseResult,
        {
            let orig_size_range = std::env::var(SIZE_RANGE_KEY).ok();
            if orig_size_range.is_none() {
                // SAFETY: `set_var` is only sound if no other thread accesses the environment
                // concurrently. NOTE(review): the test harness is multi-threaded, so this relies
                // on other tests not touching the environment at the same time; confirm.
                unsafe {
                    std::env::set_var(SIZE_RANGE_KEY, MAX_SIZE_RANGE.to_string());
                }
            }

            let mut runner = TestRunner::new(ProptestConfig { cases, ..Default::default() });
            let result = runner.run(&any::<T>(), f);

            // Restore before reporting the result so a failing case cannot leak the override.
            match orig_size_range {
                // SAFETY: same requirement as the `set_var` above.
                Some(value) => unsafe { std::env::set_var(SIZE_RANGE_KEY, value) },
                // SAFETY: same requirement as the `set_var` above.
                None => unsafe { std::env::remove_var(SIZE_RANGE_KEY) },
            }

            result.unwrap();
        }

        #[test]
        fn prop_program_msgpack_roundtrip() {
            run_with_max_size_range(100, |(program, compact): (Program<TestField>, bool)| {
                let bz = msgpack_serialize(&program, compact)?;
                let de = msgpack_deserialize(&bz)?;
                prop_assert_eq!(program, de);
                Ok(())
            });
        }

        #[test]
        fn prop_program_roundtrip() {
            run_with_max_size_range(10, |program: Program<TestField>| {
                let bz = Program::serialize_program(&program);
                let de = Program::deserialize_program(&bz)?;
                prop_assert_eq!(program, de);
                Ok(())
            });
        }

        #[test]
        fn prop_witness_stack_msgpack_roundtrip() {
            run_with_max_size_range(10, |(witness, compact): (WitnessStack<TestField>, bool)| {
                let bz = msgpack_serialize(&witness, compact)?;
                let de = msgpack_deserialize(&bz)?;
                prop_assert_eq!(witness, de);
                Ok(())
            });
        }

        #[test]
        fn prop_witness_stack_roundtrip() {
            run_with_max_size_range(10, |witness: WitnessStack<TestField>| {
                let bz = witness.serialize()?;
                let de = WitnessStack::deserialize(bz.as_slice())?;
                prop_assert_eq!(witness, de);
                Ok(())
            });
        }

        #[test]
        fn prop_witness_map_msgpack_roundtrip() {
            run_with_max_size_range(10, |(witness, compact): (WitnessMap<TestField>, bool)| {
                let bz = msgpack_serialize(&witness, compact)?;
                let de = msgpack_deserialize(&bz)?;
                prop_assert_eq!(witness, de);
                Ok(())
            });
        }

        #[test]
        fn prop_witness_map_roundtrip() {
            run_with_max_size_range(10, |witness: WitnessMap<TestField>| {
                let bz = witness.serialize()?;
                let de = WitnessMap::deserialize(bz.as_slice())?;
                prop_assert_eq!(witness, de);
                Ok(())
            });
        }
    }
}