acir/circuit/
mod.rs

1//! Native structures for representing ACIR
2
3pub mod black_box_functions;
4pub mod brillig;
5pub mod opcodes;
6
7use crate::{
8    SerializationFormat,
9    circuit::opcodes::display_opcode,
10    native_types::{Expression, Witness},
11    serialization::{self, deserialize_any_format, serialize_with_format},
12};
13use acir_field::AcirField;
14use msgpack_tagged::MsgpackTagged;
15pub use opcodes::Opcode;
16use thiserror::Error;
17
18use std::{collections::HashMap, io::prelude::*, num::ParseIntError, str::FromStr};
19
20use base64::Engine;
21use flate2::Compression;
22use serde::{Deserialize, Deserializer, Serialize, Serializer, de::Error as DeserializationError};
23
24use std::collections::BTreeSet;
25
26use self::{brillig::BrilligBytecode, opcodes::BlockId};
27
/// A program represented by multiple ACIR [circuits][Circuit]. The execution trace of these
/// circuits is dictated by construction of the [`crate::native_types::WitnessStack`].
///
/// The `#[tag(..)]` numbers identify fields in the `MsgpackTagged` encoding.
/// NOTE(review): presumably they must stay stable across versions for compatibility — confirm
/// against the `msgpack_tagged` crate before renumbering.
#[derive(Clone, PartialEq, Eq, Serialize, Deserialize, Default, Hash, MsgpackTagged)]
#[cfg_attr(feature = "arb", derive(proptest_derive::Arbitrary))]
#[tagged(allow_unknown_tags)]
pub struct Program<F: AcirField> {
    /// The constrained functions of the program, each represented as an ACIR [`Circuit`].
    #[tag(0)]
    pub functions: Vec<Circuit<F>>,
    /// The unconstrained functions of the program, each represented as [`BrilligBytecode`].
    #[tag(1)]
    pub unconstrained_functions: Vec<BrilligBytecode<F>>,
}
39
/// Representation of a single ACIR circuit. The execution trace of this structure
/// is dictated by the construction of a [`crate::native_types::WitnessMap`].
///
/// The `#[tag(..)]` numbers identify fields in the `MsgpackTagged` encoding.
/// NOTE(review): presumably they must stay stable across versions for compatibility — confirm.
#[derive(Clone, PartialEq, Eq, Default, Hash, Serialize, Deserialize, MsgpackTagged)]
#[cfg_attr(feature = "arb", derive(proptest_derive::Arbitrary))]
#[tagged(allow_unknown_tags)]
pub struct Circuit<F: AcirField> {
    /// Name of the function represented by this circuit.
    #[tag(0)]
    pub function_name: String,
    /// The circuit opcodes representing the relationship between witness values.
    ///
    /// The opcodes should be further converted into a backend-specific circuit representation.
    /// When initial witness inputs are provided, these opcodes can also be used for generating an execution trace.
    #[tag(1)]
    pub opcodes: Vec<Opcode<F>>,
    /// The set of private inputs to the circuit.
    #[tag(2)]
    pub private_parameters: BTreeSet<Witness>,
    // ACIR distinguishes between the public inputs which are provided externally or calculated within the circuit and returned.
    // These sets are not necessarily disjoint, e.g. a parameter may also be returned from the circuit.
    // All public inputs (parameters and return values) must be provided to the verifier at verification time.
    /// The set of public inputs provided by the prover.
    #[tag(3)]
    pub public_parameters: PublicInputs,
    /// The set of public inputs calculated within the circuit.
    #[tag(4)]
    pub return_values: PublicInputs,
    /// Maps opcode locations to failed assertion payloads.
    /// The data in the payload is embedded in the circuit to provide useful feedback to users
    /// when a constraint in the circuit is not satisfied.
    ///
    // Note: This should be a BTreeMap, but serde-reflect is creating invalid
    // c++ code at the moment when it is, due to OpcodeLocation needing a comparison
    // implementation which is never generated.
    #[tag(5)]
    pub assert_messages: Vec<(OpcodeLocation, AssertionPayload<F>)>,
}
77
/// Enumeration of either an [expression][Expression] or a [memory identifier][BlockId].
///
/// Used as an element of [`AssertionPayload::payload`].
#[derive(Debug, Clone, PartialEq, Eq, Hash)]
#[derive(Serialize, Deserialize, MsgpackTagged)]
#[cfg_attr(feature = "arb", derive(proptest_derive::Arbitrary))]
pub enum ExpressionOrMemory<F> {
    /// An inline [`Expression`].
    #[tag(0)]
    Expression(Expression<F>),
    /// A reference to a memory block, identified by its [`BlockId`].
    #[tag(1)]
    Memory(BlockId),
}
88
/// Payload tied to an assertion failure.
/// This data allows users to specify feedback upon a constraint not being satisfied in the circuit.
#[derive(Debug, Clone, PartialEq, Eq, Hash)]
#[derive(Serialize, Deserialize, MsgpackTagged)]
#[cfg_attr(feature = "arb", derive(proptest_derive::Arbitrary))]
pub struct AssertionPayload<F> {
    /// Selector that maps a hash of either a constant string or an internal compiler error type
    /// to an ABI type. The ABI type should then be used to appropriately resolve the payload data.
    ///
    /// Stored as a raw `u64`; wrap it with [`ErrorSelector::new`] to look it up in an
    /// error-type table.
    #[tag(0)]
    pub error_selector: u64,
    /// The dynamic payload data.
    ///
    /// Upon fetching the appropriate ABI type from the `error_selector`, the values
    /// in this payload can be decoded into the given ABI type.
    /// The payload is expected to be empty in the case of a constant string
    /// as the string can be contained entirely within the error type and ABI type.
    #[tag(1)]
    pub payload: Vec<ExpressionOrMemory<F>>,
}
108
/// Value for differentiating error types. Used internally by an [`AssertionPayload`].
///
/// A thin newtype over the raw `u64` selector.
#[derive(Debug, Copy, PartialEq, Eq, Hash, Clone, PartialOrd, Ord)]
pub struct ErrorSelector(u64);

impl ErrorSelector {
    /// Wraps a raw selector value.
    pub fn new(integer: u64) -> Self {
        Self(integer)
    }

    /// Returns the raw selector value.
    pub fn as_u64(&self) -> u64 {
        let Self(raw) = *self;
        raw
    }
}
122
123impl Serialize for ErrorSelector {
124    fn serialize<S>(&self, serializer: S) -> Result<S::Ok, S::Error>
125    where
126        S: Serializer,
127    {
128        self.0.to_string().serialize(serializer)
129    }
130}
131
132impl<'de> Deserialize<'de> for ErrorSelector {
133    fn deserialize<D>(deserializer: D) -> Result<Self, D::Error>
134    where
135        D: Deserializer<'de>,
136    {
137        let s: String = Deserialize::deserialize(deserializer)?;
138        let as_u64 = s.parse().map_err(serde::de::Error::custom)?;
139        Ok(ErrorSelector(as_u64))
140    }
141}
142
#[derive(Debug, Copy, Clone, PartialEq, Eq, Hash, PartialOrd, Ord)]
#[derive(Serialize, Deserialize, MsgpackTagged)]
#[cfg_attr(feature = "arb", derive(proptest_derive::Arbitrary))]
/// Opcodes are locatable so that callers can
/// map opcodes to debug information related to their context.
///
/// Displayed as `"{index}"` for [`OpcodeLocation::Acir`] and as
/// `"{acir_index}.{brillig_index}"` for [`OpcodeLocation::Brillig`]; `FromStr` accepts the same forms.
pub enum OpcodeLocation {
    /// The index of an opcode within a circuit's `opcodes` list.
    #[tag(0)]
    Acir(usize),
    // TODO(https://github.com/noir-lang/noir/issues/5792): We can not get rid of this enum field entirely just yet as this format is still
    // used for resolving assert messages which is a breaking serialization change.
    /// A Brillig opcode location, paired with the ACIR opcode index it is associated with.
    #[tag(1)]
    Brillig {
        #[tag(0)]
        acir_index: usize,
        #[tag(1)]
        brillig_index: usize,
    },
}
161
162/// Opcodes are locatable so that callers can
163/// map opcodes to debug information related to their context.
164#[derive(Debug, Copy, Clone, PartialEq, Eq, Hash, PartialOrd, Ord, Serialize, Deserialize)]
165pub struct AcirOpcodeLocation(usize);
166impl std::fmt::Display for AcirOpcodeLocation {
167    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
168        write!(f, "{}", self.0)
169    }
170}
171
172impl AcirOpcodeLocation {
173    pub fn new(index: usize) -> Self {
174        AcirOpcodeLocation(index)
175    }
176    pub fn index(&self) -> usize {
177        self.0
178    }
179}
180/// Index of Brillig opcode within a list of Brillig opcodes.
181/// To be used by callers for resolving debug information.
182#[derive(Debug, Copy, Clone, PartialEq, Eq, Hash, PartialOrd, Ord, Serialize, Deserialize)]
183pub struct BrilligOpcodeLocation(usize);
184
185impl BrilligOpcodeLocation {
186    pub fn new(index: usize) -> Self {
187        BrilligOpcodeLocation(index)
188    }
189    pub fn index(&self) -> usize {
190        self.0
191    }
192}
193
194impl OpcodeLocation {
195    // Utility method to allow easily comparing a resolved Brillig location and a debug Brillig location.
196    // This method is useful when fetching Brillig debug locations as this does not need an ACIR index,
197    // and just need the Brillig index.
198    pub fn to_brillig_location(self) -> Option<BrilligOpcodeLocation> {
199        match self {
200            OpcodeLocation::Brillig { brillig_index, .. } => {
201                Some(BrilligOpcodeLocation::new(brillig_index))
202            }
203            OpcodeLocation::Acir(_) => None,
204        }
205    }
206}
207
208impl std::fmt::Display for OpcodeLocation {
209    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
210        match self {
211            OpcodeLocation::Acir(index) => write!(f, "{index}"),
212            OpcodeLocation::Brillig { acir_index, brillig_index } => {
213                write!(f, "{acir_index}.{brillig_index}")
214            }
215        }
216    }
217}
218
/// Error returned when parsing an [`OpcodeLocation`] from a string fails.
#[derive(Error, Debug)]
pub enum OpcodeLocationFromStrError {
    /// The input string, carried verbatim, is not a valid opcode location.
    #[error("Invalid opcode location string: {0}")]
    InvalidOpcodeLocationString(String),
}
224
225/// The implementation of display and `FromStr` allows serializing and deserializing a `OpcodeLocation` to a string.
226/// This is useful when used as key in a map that has to be serialized to JSON/TOML, for example when mapping an opcode to its metadata.
227impl FromStr for OpcodeLocation {
228    type Err = OpcodeLocationFromStrError;
229    fn from_str(s: &str) -> Result<Self, Self::Err> {
230        let parts: Vec<_> = s.split('.').collect();
231
232        if parts.is_empty() || parts.len() > 2 {
233            return Err(OpcodeLocationFromStrError::InvalidOpcodeLocationString(s.to_string()));
234        }
235
236        fn parse_components(parts: &[&str]) -> Result<OpcodeLocation, ParseIntError> {
237            match parts.len() {
238                1 => {
239                    let index = parts[0].parse()?;
240                    Ok(OpcodeLocation::Acir(index))
241                }
242                2 => {
243                    let acir_index = parts[0].parse()?;
244                    let brillig_index = parts[1].parse()?;
245                    Ok(OpcodeLocation::Brillig { acir_index, brillig_index })
246                }
247                _ => unreachable!("`OpcodeLocation` has too many components"),
248            }
249        }
250
251        parse_components(&parts)
252            .map_err(|_| OpcodeLocationFromStrError::InvalidOpcodeLocationString(s.to_string()))
253    }
254}
255
256impl std::fmt::Display for BrilligOpcodeLocation {
257    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
258        let index = self.0;
259        write!(f, "{index}")
260    }
261}
262
263impl<F: AcirField> Circuit<F> {
264    /// Returns all witnesses which are required to execute the circuit successfully.
265    pub fn circuit_arguments(&self) -> BTreeSet<Witness> {
266        self.private_parameters.union(&self.public_parameters.0).copied().collect()
267    }
268
269    /// Returns all public inputs. This includes those provided as parameters to the circuit and those
270    /// computed as return values.
271    pub fn public_inputs(&self) -> PublicInputs {
272        let public_inputs =
273            self.public_parameters.0.union(&self.return_values.0).copied().collect();
274        PublicInputs(public_inputs)
275    }
276}
277
278impl<F: Serialize + AcirField + MsgpackTagged> Program<F> {
279    /// Compress a serialized [Program].
280    fn compress(buf: &[u8]) -> std::io::Result<Vec<u8>> {
281        let mut compressed: Vec<u8> = Vec::new();
282        // Compress the data, which should help with formats that uses field names.
283        let mut encoder = flate2::write::GzEncoder::new(&mut compressed, Compression::default());
284        encoder.write_all(buf)?;
285        encoder.finish()?;
286        Ok(compressed)
287    }
288
289    /// Serialize and compress a [Program] into bytes, using the given format.
290    pub fn serialize_program_with_format(program: &Self, format: serialization::Format) -> Vec<u8> {
291        let program_bytes =
292            serialize_with_format(program, format).expect("expected circuit to be serializable");
293        Self::compress(&program_bytes).expect("expected circuit to compress")
294    }
295
296    /// Serialize and compress a [Program] into bytes, using the format from the environment, or the default format.
297    pub fn serialize_program(program: &Self) -> Vec<u8> {
298        let format = SerializationFormat::from_env().expect("invalid format");
299        Self::serialize_program_with_format(program, format.unwrap_or_default())
300    }
301
302    /// Serialize, compress then base64 encode a [Program], using the format from the environment, or the default format,
303    pub fn serialize_program_base64<S>(program: &Self, s: S) -> Result<S::Ok, S::Error>
304    where
305        S: Serializer,
306    {
307        let program_bytes = Program::serialize_program(program);
308        let encoded_b64 = base64::engine::general_purpose::STANDARD.encode(program_bytes);
309        s.serialize_str(&encoded_b64)
310    }
311}
312
313impl<F: AcirField + for<'a> Deserialize<'a> + MsgpackTagged> Program<F> {
314    /// Decompress and deserialize bytes into a [Program].
315    fn read<R: Read>(reader: R) -> std::io::Result<Self> {
316        let mut gz_decoder = flate2::read::GzDecoder::new(reader);
317        let mut buf = Vec::new();
318        gz_decoder.read_to_end(&mut buf)?;
319        let program = deserialize_any_format(&buf)?;
320        Ok(program)
321    }
322
323    /// Deserialize bytecode.
324    pub fn deserialize_program(serialized_circuit: &[u8]) -> std::io::Result<Self> {
325        Program::read(serialized_circuit)
326    }
327
328    /// Deserialize and base64 decode program
329    pub fn deserialize_program_base64<'de, D>(deserializer: D) -> Result<Self, D::Error>
330    where
331        D: Deserializer<'de>,
332    {
333        let bytecode_b64: String = Deserialize::deserialize(deserializer)?;
334        let program_bytes = base64::engine::general_purpose::STANDARD
335            .decode(bytecode_b64)
336            .map_err(D::Error::custom)?;
337        let circuit = Self::deserialize_program(&program_bytes).map_err(D::Error::custom)?;
338        Ok(circuit)
339    }
340}
341
342impl<F: AcirField> std::fmt::Display for Circuit<F> {
343    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
344        display_circuit(self, None, f)
345    }
346}
347
348pub fn display_circuit<F: AcirField>(
349    circuit: &Circuit<F>,
350    error_types: Option<&HashMap<ErrorSelector, String>>,
351    f: &mut std::fmt::Formatter<'_>,
352) -> std::fmt::Result {
353    let write_witness_indices =
354        |f: &mut std::fmt::Formatter<'_>, indices: &[u32]| -> Result<(), std::fmt::Error> {
355            write!(f, "[")?;
356            for (index, witness_index) in indices.iter().enumerate() {
357                write!(f, "w{witness_index}")?;
358                if index != indices.len() - 1 {
359                    write!(f, ", ")?;
360                }
361            }
362            writeln!(f, "]")
363        };
364
365    write!(f, "private parameters: ")?;
366    write_witness_indices(
367        f,
368        &circuit
369            .private_parameters
370            .iter()
371            .map(|witness| witness.witness_index())
372            .collect::<Vec<_>>(),
373    )?;
374
375    write!(f, "public parameters: ")?;
376    write_witness_indices(f, &circuit.public_parameters.indices())?;
377
378    write!(f, "return values: ")?;
379    write_witness_indices(f, &circuit.return_values.indices())?;
380
381    let assert_messages_by_opcode_location =
382        circuit.assert_messages.iter().cloned().collect::<HashMap<_, _>>();
383
384    for (index, opcode) in circuit.opcodes.iter().enumerate() {
385        display_opcode(opcode, Some(&circuit.return_values), f)?;
386
387        if let Some(error_types) = error_types {
388            let location = OpcodeLocation::Acir(index);
389            if let Some(payload) = assert_messages_by_opcode_location.get(&location)
390                && let Some(message) = error_types.get(&ErrorSelector::new(payload.error_selector))
391            {
392                write!(f, " // {message}")?;
393            }
394        }
395        writeln!(f)?;
396    }
397    Ok(())
398}
399
400impl<F: AcirField> std::fmt::Debug for Circuit<F> {
401    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
402        std::fmt::Display::fmt(self, f)
403    }
404}
405
406impl<F: AcirField> std::fmt::Display for Program<F> {
407    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
408        display_program(self, None, f)
409    }
410}
411
412pub fn display_program<F: AcirField>(
413    program: &Program<F>,
414    error_types: Option<&HashMap<ErrorSelector, String>>,
415    f: &mut std::fmt::Formatter<'_>,
416) -> std::fmt::Result {
417    for (func_index, function) in program.functions.iter().enumerate() {
418        writeln!(f, "func {func_index}")?;
419        display_circuit(function, error_types, f)?;
420        writeln!(f)?;
421    }
422    for (func_index, function) in program.unconstrained_functions.iter().enumerate() {
423        writeln!(f, "unconstrained func {func_index}: {}", function.function_name)?;
424        let width = function.bytecode.len().to_string().len();
425        for (index, opcode) in function.bytecode.iter().enumerate() {
426            write!(f, "{index:>width$}: {opcode}")?;
427
428            if let ::brillig::Opcode::IndirectConst { value, .. } = opcode
429                && let Some(value) = value.try_to_u64()
430                && let Some(message) =
431                    error_types.and_then(|error_types| error_types.get(&ErrorSelector::new(value)))
432            {
433                write!(f, " // {message:?}")?;
434            }
435
436            writeln!(f)?;
437        }
438    }
439    Ok(())
440}
441
442impl<F: AcirField> std::fmt::Debug for Program<F> {
443    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
444        std::fmt::Display::fmt(self, f)
445    }
446}
447
448#[derive(Clone, Debug, PartialEq, Eq, Serialize, Deserialize, Default, Hash, MsgpackTagged)]
449#[cfg_attr(feature = "arb", derive(proptest_derive::Arbitrary))]
450pub struct PublicInputs(pub BTreeSet<Witness>);
451
452impl PublicInputs {
453    /// Returns the witness index of each public input
454    pub fn indices(&self) -> Vec<u32> {
455        self.0.iter().map(|witness| witness.witness_index()).collect()
456    }
457
458    pub fn contains(&self, index: usize) -> bool {
459        self.0.contains(&Witness::new(index as u32))
460    }
461}
462
#[cfg(test)]
mod tests {
    use super::{Circuit, Compression};
    use crate::circuit::Program;
    use acir_field::{AcirField, FieldElement};
    use msgpack_tagged::MsgpackTagged;
    use serde::{Deserialize, Serialize};

    #[test]
    fn serialization_roundtrip() {
        let src = "
        private parameters: []
        public parameters: [w2, w12]
        return values: [w4, w12]
        BLACKBOX::AND lhs: w1, rhs: w2, output: w3, bits: 4
        BLACKBOX::RANGE input: w1, bits: 8
        ";
        let circuit = Circuit::from_str(src).unwrap();
        let program = Program { functions: vec![circuit], unconstrained_functions: Vec::new() };

        fn read_write<F: Serialize + for<'a> Deserialize<'a> + AcirField + MsgpackTagged>(
            program: Program<F>,
        ) -> (Program<F>, Program<F>) {
            let bytes = Program::serialize_program(&program);
            let got_program = Program::deserialize_program(&bytes).unwrap();
            (program, got_program)
        }

        let (circ, got_circ) = read_write(program);
        assert_eq!(circ, got_circ);
    }

    #[test]
    fn test_serialize() {
        let src = "
        private parameters: []
        public parameters: [w2]
        return values: [w2]
        ASSERT 0 = 8
        BLACKBOX::RANGE input: w1, bits: 8
        BLACKBOX::AND lhs: w1, rhs: w2, output: w3, bits: 4
        BLACKBOX::KECCAKF1600 inputs: [w1, w2, w3, w4, w5, w6, w7, w8, w9, w10, w11, w12, w13, w14, w15, w16, w17, w18, w19, w20, w21, w22, w23, w24, w25], outputs: [w26, w27, w28, w29, w30, w31, w32, w33, w34, w35, w36, w37, w38, w39, w40, w41, w42, w43, w44, w45, w46, w47, w48, w49, w50]
        ";
        let circuit = Circuit::from_str(src).unwrap();
        let program = Program { functions: vec![circuit], unconstrained_functions: Vec::new() };

        let json = serde_json::to_string_pretty(&program).unwrap();

        let deserialized = serde_json::from_str(&json).unwrap();
        assert_eq!(program, deserialized);
    }

    #[test]
    fn does_not_panic_on_invalid_circuit() {
        use std::io::Write;

        let bad_circuit = "I'm not an ACIR circuit".as_bytes();

        // We expect to load circuits as compressed artifacts so we compress the junk circuit.
        let mut zipped_bad_circuit = Vec::new();
        let mut encoder =
            flate2::write::GzEncoder::new(&mut zipped_bad_circuit, Compression::default());
        encoder.write_all(bad_circuit).unwrap();
        encoder.finish().unwrap();

        let deserialization_result: Result<Program<FieldElement>, _> =
            Program::deserialize_program(&zipped_bad_circuit);
        assert!(deserialization_result.is_err());
    }

    #[test]
    fn circuit_display_snapshot() {
        let src = "
        private parameters: []
        public parameters: [w2]
        return values: [w2]
        ASSERT 0 = 2*w1 + 8
        BLACKBOX::RANGE input: w1, bits: 8
        BLACKBOX::AND lhs: w1, rhs: w2, output: w3, bits: 4
        BLACKBOX::KECCAKF1600 inputs: [w1, w2, w3, w4, w5, w6, w7, w8, w9, w10, w11, w12, w13, w14, w15, w16, w17, w18, w19, w20, w21, w22, w23, w24, w25], outputs: [w26, w27, w28, w29, w30, w31, w32, w33, w34, w35, w36, w37, w38, w39, w40, w41, w42, w43, w44, w45, w46, w47, w48, w49, w50]
        ";
        let circuit = Circuit::from_str(src).unwrap();

        // All witnesses are expected to be formatted as `w{witness_index}`.
        insta::assert_snapshot!(
            circuit.to_string(),
            @r"
        private parameters: []
        public parameters: [w2]
        return values: [w2]
        ASSERT 0 = 2*w1 + 8
        BLACKBOX::RANGE input: w1, bits: 8
        BLACKBOX::AND lhs: w1, rhs: w2, output: w3, bits: 4
        BLACKBOX::KECCAKF1600 inputs: [w1, w2, w3, w4, w5, w6, w7, w8, w9, w10, w11, w12, w13, w14, w15, w16, w17, w18, w19, w20, w21, w22, w23, w24, w25], outputs: [w26, w27, w28, w29, w30, w31, w32, w33, w34, w35, w36, w37, w38, w39, w40, w41, w42, w43, w44, w45, w46, w47, w48, w49, w50]
        "
        );
    }

    /// This test is a reminder that when we slap `#[derive(MsgpackTagged)]` on type,
    /// we only get immediate compilation errors for non-generic fields that
    /// don't implement `MsgpackTagged`; for generic types like `Circuit<F>`
    /// in `Program<F>`, we may or may not have an implementation, depending on `F`.
    ///
    /// A test like this brings out all the concrete missing implementations by
    /// choosing a particular `F`.
    #[test]
    fn ensure_program_is_msgpack_tagged() {
        fn assert_impl<T: MsgpackTagged>() {}
        assert_impl::<Program<FieldElement>>();
    }

    /// Property based testing for serialization
    mod props {
        use acir_field::FieldElement;
        use msgpack_tagged::MsgpackTagged;
        use proptest::prelude::*;
        use proptest::test_runner::{TestCaseResult, TestRunner};

        use crate::circuit::Program;
        use crate::native_types::{WitnessMap, WitnessStack};
        use crate::serialization::*;

        // It's not possible to set the maximum size of collections via `ProptestConfig`, only an env var,
        // because e.g. the `VecStrategy` uses `Config::default().max_default_size_range`. On top of that,
        // `Config::default()` reads a static `DEFAULT_CONFIG`, which gets the env vars only once at the
        // beginning, so we can't override this on a test-by-test basis, unless we use `fork`,
        // which is a feature that is currently disabled, because it doesn't work with Wasm.
        // We could add it as a `dev-dependency` just for this crate, but when I tried it just crashed.
        // For now using a const so it's obvious we can't set it to different values for different tests.
        const MAX_SIZE_RANGE: usize = 5;
        const SIZE_RANGE_KEY: &str = "PROPTEST_MAX_DEFAULT_SIZE_RANGE";

        // Define a wrapper around field so we can implement `Arbitrary`.
        // NB there are other methods like `arbitrary_field_elements` around the codebase,
        // but for `proptest_derive::Arbitrary` we need `F: AcirField + Arbitrary`.
        acir_field::field_wrapper!(TestField, FieldElement);

        impl Arbitrary for TestField {
            type Parameters = ();
            type Strategy = BoxedStrategy<Self>;

            fn arbitrary_with(_args: Self::Parameters) -> Self::Strategy {
                any::<u128>().prop_map(|v| Self(FieldElement::from(v))).boxed()
            }
        }

        impl MsgpackTagged for TestField {
            const TAGGED: msgpack_tagged::Tagged = msgpack_tagged::Tagged::empty_product();
            fn register_into(_reg: &mut msgpack_tagged::TagRegistry) {}
        }

        /// Override the maximum size of collections created by `proptest`.
        ///
        /// If `PROPTEST_MAX_DEFAULT_SIZE_RANGE` is unset, it is set to [`MAX_SIZE_RANGE`] for
        /// the duration of the run and removed again afterwards; an existing value is left
        /// untouched. Runs `cases` cases of `f` on arbitrary `T` values and panics if any fails.
        // `std::env::set_var` / `std::env::remove_var` are `unsafe fn`s in Rust 2024.
        #[allow(unsafe_code)]
        fn run_with_max_size_range<T, F>(cases: u32, f: F)
        where
            T: Arbitrary,
            F: Fn(T) -> TestCaseResult,
        {
            let orig_size_range = std::env::var(SIZE_RANGE_KEY).ok();
            // The defaults are only read once. If they are already set, leave them be.
            let override_size_range = orig_size_range.is_none();
            if override_size_range {
                // SAFETY: mutating the environment is only sound while no other thread reads or
                // writes it. NOTE(review): the test harness runs tests in parallel, so this relies
                // on other tests not touching the environment concurrently; it is not guaranteed.
                unsafe {
                    std::env::set_var(SIZE_RANGE_KEY, MAX_SIZE_RANGE.to_string());
                }
            }

            let mut runner = TestRunner::new(ProptestConfig { cases, ..Default::default() });
            let result = runner.run(&any::<T>(), f);

            // Undo our own override only. The variable was unset originally, so remove it rather
            // than leaving an empty string behind, which is not a valid size range.
            if override_size_range {
                // SAFETY: same caveat as for `set_var` above.
                unsafe {
                    std::env::remove_var(SIZE_RANGE_KEY);
                }
            }

            result.unwrap();
        }

        /// Round-trip a `Program` through each `Format` variant chosen
        /// arbitrarily. Uses `serialize_with_format` + `deserialize_any_format`
        /// so the framing-byte dispatch is exercised too — picks the right
        /// encoder/decoder pair per format under one test name.
        #[test]
        fn prop_program_arb_roundtrip() {
            run_with_max_size_range(100, |(program, format): (Program<TestField>, Format)| {
                let bz = serialize_with_format(&program, format)?;
                let de = deserialize_any_format(&bz)?;
                prop_assert_eq!(program, de);
                Ok(())
            });
        }

        #[test]
        fn prop_program_roundtrip() {
            run_with_max_size_range(10, |program: Program<TestField>| {
                let bz = Program::serialize_program(&program);
                let de = Program::deserialize_program(&bz)?;
                prop_assert_eq!(program, de);
                Ok(())
            });
        }

        #[test]
        fn prop_witness_stack_arb_roundtrip() {
            run_with_max_size_range(10, |(witness, format): (WitnessStack<TestField>, Format)| {
                let bz = serialize_with_format(&witness, format)?;
                let de = deserialize_any_format(&bz)?;
                prop_assert_eq!(witness, de);
                Ok(())
            });
        }

        #[test]
        fn prop_witness_stack_roundtrip() {
            run_with_max_size_range(10, |witness: WitnessStack<TestField>| {
                let bz = witness.serialize()?;
                let de = WitnessStack::deserialize(bz.as_slice())?;
                prop_assert_eq!(witness, de);
                Ok(())
            });
        }

        #[test]
        fn prop_witness_map_arb_roundtrip() {
            run_with_max_size_range(10, |(witness, format): (WitnessMap<TestField>, Format)| {
                let bz = serialize_with_format(&witness, format)?;
                let de = deserialize_any_format(&bz)?;
                prop_assert_eq!(witness, de);
                Ok(())
            });
        }

        #[test]
        fn prop_witness_map_roundtrip() {
            run_with_max_size_range(10, |witness: WitnessMap<TestField>| {
                let bz = witness.serialize()?;
                let de = WitnessMap::deserialize(bz.as_slice())?;
                prop_assert_eq!(witness, de);
                Ok(())
            });
        }
    }
}