acir/circuit/
mod.rs

1//! Native structures for representing ACIR
2
3pub mod black_box_functions;
4pub mod brillig;
5pub mod opcodes;
6
7use crate::{
8    SerializationFormat,
9    circuit::opcodes::display_opcode,
10    native_types::{Expression, Witness},
11    serialization::{self, deserialize_any_format, serialize_with_format},
12};
13use acir_field::AcirField;
14pub use opcodes::Opcode;
15use thiserror::Error;
16
17use std::{collections::HashMap, io::prelude::*, num::ParseIntError, str::FromStr};
18
19use base64::Engine;
20use flate2::Compression;
21use serde::{Deserialize, Deserializer, Serialize, Serializer, de::Error as DeserializationError};
22
23use std::collections::BTreeSet;
24
25use self::{brillig::BrilligBytecode, opcodes::BlockId};
26
27/// A program represented by multiple ACIR [circuit][Circuit]'s. The execution trace of these
28/// circuits is dictated by construction of the [crate::native_types::WitnessStack].
29#[derive(Clone, PartialEq, Eq, Serialize, Deserialize, Default, Hash)]
30#[cfg_attr(feature = "arb", derive(proptest_derive::Arbitrary))]
31pub struct Program<F: AcirField> {
32    pub functions: Vec<Circuit<F>>,
33    pub unconstrained_functions: Vec<BrilligBytecode<F>>,
34}
35
36/// Representation of a single ACIR circuit. The execution trace of this structure
37/// is dictated by the construction of a [crate::native_types::WitnessMap]
38#[derive(Clone, PartialEq, Eq, Default, Hash)]
39#[cfg_attr(feature = "arb", derive(proptest_derive::Arbitrary))]
40pub struct Circuit<F: AcirField> {
41    /// Name of the function represented by this circuit.
42    pub function_name: String,
43    /// The circuit opcodes representing the relationship between witness values.
44    ///
45    /// The opcodes should be further converted into a backend-specific circuit representation.
46    /// When initial witness inputs are provided, these opcodes can also be used for generating an execution trace.
47    pub opcodes: Vec<Opcode<F>>,
48    /// The set of private inputs to the circuit.
49    pub private_parameters: BTreeSet<Witness>,
50    // ACIR distinguishes between the public inputs which are provided externally or calculated within the circuit and returned.
51    // The elements of these sets may not be mutually exclusive, i.e. a parameter may be returned from the circuit.
52    // All public inputs (parameters and return values) must be provided to the verifier at verification time.
53    /// The set of public inputs provided by the prover.
54    pub public_parameters: PublicInputs,
55    /// The set of public inputs calculated within the circuit.
56    pub return_values: PublicInputs,
57    /// Maps opcode locations to failed assertion payloads.
58    /// The data in the payload is embedded in the circuit to provide useful feedback to users
59    /// when a constraint in the circuit is not satisfied.
60    ///
61    // Note: This should be a BTreeMap, but serde-reflect is creating invalid
62    // c++ code at the moment when it is, due to OpcodeLocation needing a comparison
63    // implementation which is never generated.
64    pub assert_messages: Vec<(OpcodeLocation, AssertionPayload<F>)>,
65}
66
67/// Wire format for `Circuit` — preserves backwards-compatible serialization that includes
68/// `current_witness_index`. The `serde(rename)` ensures this type registers under the same
69/// name ("Circuit") as the public type so that `serde_reflection` traces it correctly.
70#[derive(Serialize, Deserialize)]
71#[serde(rename = "Circuit")]
72struct CircuitWire<F: AcirField> {
73    #[serde(default)]
74    function_name: String,
75    current_witness_index: u32,
76    opcodes: Vec<Opcode<F>>,
77    private_parameters: BTreeSet<Witness>,
78    public_parameters: PublicInputs,
79    return_values: PublicInputs,
80    assert_messages: Vec<(OpcodeLocation, AssertionPayload<F>)>,
81}
82
83impl<F: AcirField + Serialize> Serialize for Circuit<F> {
84    fn serialize<S: Serializer>(&self, serializer: S) -> Result<S::Ok, S::Error> {
85        CircuitWire {
86            function_name: self.function_name.clone(),
87            current_witness_index: 0,
88            opcodes: self.opcodes.clone(),
89            private_parameters: self.private_parameters.clone(),
90            public_parameters: self.public_parameters.clone(),
91            return_values: self.return_values.clone(),
92            assert_messages: self.assert_messages.clone(),
93        }
94        .serialize(serializer)
95    }
96}
97
98impl<'de, F: AcirField + Deserialize<'de>> Deserialize<'de> for Circuit<F> {
99    fn deserialize<D: Deserializer<'de>>(deserializer: D) -> Result<Self, D::Error> {
100        let wire = CircuitWire::<F>::deserialize(deserializer)?;
101        Ok(Circuit {
102            function_name: wire.function_name,
103            opcodes: wire.opcodes,
104            private_parameters: wire.private_parameters,
105            public_parameters: wire.public_parameters,
106            return_values: wire.return_values,
107            assert_messages: wire.assert_messages,
108        })
109    }
110}
111
112/// Enumeration of either an [expression][Expression] or a [memory identifier][BlockId].
113#[derive(Debug, Clone, PartialEq, Eq, Serialize, Deserialize, Hash)]
114#[cfg_attr(feature = "arb", derive(proptest_derive::Arbitrary))]
115pub enum ExpressionOrMemory<F> {
116    Expression(Expression<F>),
117    Memory(BlockId),
118}
119
120/// Payload tied to an assertion failure.
121/// This data allows users to specify feedback upon a constraint not being satisfied in the circuit.
122#[derive(Debug, Clone, PartialEq, Eq, Serialize, Deserialize, Hash)]
123#[cfg_attr(feature = "arb", derive(proptest_derive::Arbitrary))]
124pub struct AssertionPayload<F> {
125    /// Selector that maps a hash of either a constant string or an internal compiler error type
126    /// to an ABI type. The ABI type should then be used to appropriately resolve the payload data.
127    pub error_selector: u64,
128    /// The dynamic payload data.
129    ///
130    /// Upon fetching the appropriate ABI type from the `error_selector`, the values
131    /// in this payload can be decoded into the given ABI type.
132    /// The payload is expected to be empty in the case of a constant string
133    /// as the string can be contained entirely within the error type and ABI type.
134    pub payload: Vec<ExpressionOrMemory<F>>,
135}
136
/// Identifies an error type; used internally by an [AssertionPayload].
///
/// Serialized as a decimal string rather than as a number.
#[derive(Debug, Clone, Copy, PartialEq, Eq, PartialOrd, Ord, Hash)]
pub struct ErrorSelector(u64);
140
141impl ErrorSelector {
142    pub fn new(integer: u64) -> Self {
143        ErrorSelector(integer)
144    }
145
146    pub fn as_u64(&self) -> u64 {
147        self.0
148    }
149}
150
151impl Serialize for ErrorSelector {
152    fn serialize<S>(&self, serializer: S) -> Result<S::Ok, S::Error>
153    where
154        S: Serializer,
155    {
156        self.0.to_string().serialize(serializer)
157    }
158}
159
160impl<'de> Deserialize<'de> for ErrorSelector {
161    fn deserialize<D>(deserializer: D) -> Result<Self, D::Error>
162    where
163        D: Deserializer<'de>,
164    {
165        let s: String = Deserialize::deserialize(deserializer)?;
166        let as_u64 = s.parse().map_err(serde::de::Error::custom)?;
167        Ok(ErrorSelector(as_u64))
168    }
169}
170
171#[derive(Debug, Copy, Clone, PartialEq, Eq, Hash, PartialOrd, Ord, Serialize, Deserialize)]
172#[cfg_attr(feature = "arb", derive(proptest_derive::Arbitrary))]
173/// Opcodes are locatable so that callers can
174/// map opcodes to debug information related to their context.
175pub enum OpcodeLocation {
176    Acir(usize),
177    // TODO(https://github.com/noir-lang/noir/issues/5792): We can not get rid of this enum field entirely just yet as this format is still
178    // used for resolving assert messages which is a breaking serialization change.
179    Brillig { acir_index: usize, brillig_index: usize },
180}
181
182/// Opcodes are locatable so that callers can
183/// map opcodes to debug information related to their context.
184#[derive(Debug, Copy, Clone, PartialEq, Eq, Hash, PartialOrd, Ord, Serialize, Deserialize)]
185pub struct AcirOpcodeLocation(usize);
186impl std::fmt::Display for AcirOpcodeLocation {
187    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
188        write!(f, "{}", self.0)
189    }
190}
191
192impl AcirOpcodeLocation {
193    pub fn new(index: usize) -> Self {
194        AcirOpcodeLocation(index)
195    }
196    pub fn index(&self) -> usize {
197        self.0
198    }
199}
200/// Index of Brillig opcode within a list of Brillig opcodes.
201/// To be used by callers for resolving debug information.
202#[derive(Debug, Copy, Clone, PartialEq, Eq, Hash, PartialOrd, Ord, Serialize, Deserialize)]
203pub struct BrilligOpcodeLocation(pub usize);
204
205impl OpcodeLocation {
206    // Utility method to allow easily comparing a resolved Brillig location and a debug Brillig location.
207    // This method is useful when fetching Brillig debug locations as this does not need an ACIR index,
208    // and just need the Brillig index.
209    pub fn to_brillig_location(self) -> Option<BrilligOpcodeLocation> {
210        match self {
211            OpcodeLocation::Brillig { brillig_index, .. } => {
212                Some(BrilligOpcodeLocation(brillig_index))
213            }
214            OpcodeLocation::Acir(_) => None,
215        }
216    }
217}
218
219impl std::fmt::Display for OpcodeLocation {
220    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
221        match self {
222            OpcodeLocation::Acir(index) => write!(f, "{index}"),
223            OpcodeLocation::Brillig { acir_index, brillig_index } => {
224                write!(f, "{acir_index}.{brillig_index}")
225            }
226        }
227    }
228}
229
230#[derive(Error, Debug)]
231pub enum OpcodeLocationFromStrError {
232    #[error("Invalid opcode location string: {0}")]
233    InvalidOpcodeLocationString(String),
234}
235
236/// The implementation of display and FromStr allows serializing and deserializing a OpcodeLocation to a string.
237/// This is useful when used as key in a map that has to be serialized to JSON/TOML, for example when mapping an opcode to its metadata.
238impl FromStr for OpcodeLocation {
239    type Err = OpcodeLocationFromStrError;
240    fn from_str(s: &str) -> Result<Self, Self::Err> {
241        let parts: Vec<_> = s.split('.').collect();
242
243        if parts.is_empty() || parts.len() > 2 {
244            return Err(OpcodeLocationFromStrError::InvalidOpcodeLocationString(s.to_string()));
245        }
246
247        fn parse_components(parts: Vec<&str>) -> Result<OpcodeLocation, ParseIntError> {
248            match parts.len() {
249                1 => {
250                    let index = parts[0].parse()?;
251                    Ok(OpcodeLocation::Acir(index))
252                }
253                2 => {
254                    let acir_index = parts[0].parse()?;
255                    let brillig_index = parts[1].parse()?;
256                    Ok(OpcodeLocation::Brillig { acir_index, brillig_index })
257                }
258                _ => unreachable!("`OpcodeLocation` has too many components"),
259            }
260        }
261
262        parse_components(parts)
263            .map_err(|_| OpcodeLocationFromStrError::InvalidOpcodeLocationString(s.to_string()))
264    }
265}
266
267impl std::fmt::Display for BrilligOpcodeLocation {
268    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
269        let index = self.0;
270        write!(f, "{index}")
271    }
272}
273
274impl<F: AcirField> Circuit<F> {
275    /// Returns all witnesses which are required to execute the circuit successfully.
276    pub fn circuit_arguments(&self) -> BTreeSet<Witness> {
277        self.private_parameters.union(&self.public_parameters.0).copied().collect()
278    }
279
280    /// Returns all public inputs. This includes those provided as parameters to the circuit and those
281    /// computed as return values.
282    pub fn public_inputs(&self) -> PublicInputs {
283        let public_inputs =
284            self.public_parameters.0.union(&self.return_values.0).copied().collect();
285        PublicInputs(public_inputs)
286    }
287}
288
289impl<F: Serialize + AcirField> Program<F> {
290    /// Compress a serialized [Program].
291    fn compress(buf: Vec<u8>) -> std::io::Result<Vec<u8>> {
292        let mut compressed: Vec<u8> = Vec::new();
293        // Compress the data, which should help with formats that uses field names.
294        let mut encoder = flate2::write::GzEncoder::new(&mut compressed, Compression::default());
295        encoder.write_all(&buf)?;
296        encoder.finish()?;
297        Ok(compressed)
298    }
299
300    /// Serialize and compress a [Program] into bytes, using the given format.
301    pub fn serialize_program_with_format(program: &Self, format: serialization::Format) -> Vec<u8> {
302        let program_bytes =
303            serialize_with_format(program, format).expect("expected circuit to be serializable");
304        Self::compress(program_bytes).expect("expected circuit to compress")
305    }
306
307    /// Serialize and compress a [Program] into bytes, using the format from the environment, or the default format.
308    pub fn serialize_program(program: &Self) -> Vec<u8> {
309        let format = SerializationFormat::from_env().expect("invalid format");
310        Self::serialize_program_with_format(program, format.unwrap_or_default())
311    }
312
313    /// Serialize, compress then base64 encode a [Program], using the format from the environment, or the default format,
314    pub fn serialize_program_base64<S>(program: &Self, s: S) -> Result<S::Ok, S::Error>
315    where
316        S: Serializer,
317    {
318        let program_bytes = Program::serialize_program(program);
319        let encoded_b64 = base64::engine::general_purpose::STANDARD.encode(program_bytes);
320        s.serialize_str(&encoded_b64)
321    }
322}
323
324impl<F: AcirField + for<'a> Deserialize<'a>> Program<F> {
325    /// Decompress and deserialize bytes into a [Program].
326    fn read<R: Read>(reader: R) -> std::io::Result<Self> {
327        let mut gz_decoder = flate2::read::GzDecoder::new(reader);
328        let mut buf = Vec::new();
329        gz_decoder.read_to_end(&mut buf)?;
330        let program = deserialize_any_format(&buf)?;
331        Ok(program)
332    }
333
334    /// Deserialize bytecode.
335    pub fn deserialize_program(serialized_circuit: &[u8]) -> std::io::Result<Self> {
336        Program::read(serialized_circuit)
337    }
338
339    /// Deserialize and base64 decode program
340    pub fn deserialize_program_base64<'de, D>(deserializer: D) -> Result<Self, D::Error>
341    where
342        D: Deserializer<'de>,
343    {
344        let bytecode_b64: String = Deserialize::deserialize(deserializer)?;
345        let program_bytes = base64::engine::general_purpose::STANDARD
346            .decode(bytecode_b64)
347            .map_err(D::Error::custom)?;
348        let circuit = Self::deserialize_program(&program_bytes).map_err(D::Error::custom)?;
349        Ok(circuit)
350    }
351}
352
353impl<F: AcirField> std::fmt::Display for Circuit<F> {
354    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
355        display_circuit(self, None, f)
356    }
357}
358
359pub fn display_circuit<F: AcirField>(
360    circuit: &Circuit<F>,
361    error_types: Option<&HashMap<ErrorSelector, String>>,
362    f: &mut std::fmt::Formatter<'_>,
363) -> std::fmt::Result {
364    let write_witness_indices =
365        |f: &mut std::fmt::Formatter<'_>, indices: &[u32]| -> Result<(), std::fmt::Error> {
366            write!(f, "[")?;
367            for (index, witness_index) in indices.iter().enumerate() {
368                write!(f, "w{witness_index}")?;
369                if index != indices.len() - 1 {
370                    write!(f, ", ")?;
371                }
372            }
373            writeln!(f, "]")
374        };
375
376    write!(f, "private parameters: ")?;
377    write_witness_indices(
378        f,
379        &circuit
380            .private_parameters
381            .iter()
382            .map(|witness| witness.witness_index())
383            .collect::<Vec<_>>(),
384    )?;
385
386    write!(f, "public parameters: ")?;
387    write_witness_indices(f, &circuit.public_parameters.indices())?;
388
389    write!(f, "return values: ")?;
390    write_witness_indices(f, &circuit.return_values.indices())?;
391
392    let assert_messages_by_opcode_location =
393        circuit.assert_messages.iter().cloned().collect::<HashMap<_, _>>();
394
395    for (index, opcode) in circuit.opcodes.iter().enumerate() {
396        display_opcode(opcode, Some(&circuit.return_values), f)?;
397
398        if let Some(error_types) = error_types {
399            let location = OpcodeLocation::Acir(index);
400            if let Some(payload) = assert_messages_by_opcode_location.get(&location)
401                && let Some(message) = error_types.get(&ErrorSelector::new(payload.error_selector))
402            {
403                write!(f, " // {message}")?;
404            }
405        }
406        writeln!(f)?;
407    }
408    Ok(())
409}
410
411impl<F: AcirField> std::fmt::Debug for Circuit<F> {
412    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
413        std::fmt::Display::fmt(self, f)
414    }
415}
416
417impl<F: AcirField> std::fmt::Display for Program<F> {
418    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
419        display_program(self, None, f)
420    }
421}
422
423pub fn display_program<F: AcirField>(
424    program: &Program<F>,
425    error_types: Option<&HashMap<ErrorSelector, String>>,
426    f: &mut std::fmt::Formatter<'_>,
427) -> std::fmt::Result {
428    for (func_index, function) in program.functions.iter().enumerate() {
429        writeln!(f, "func {func_index}")?;
430        display_circuit(function, error_types, f)?;
431        writeln!(f)?;
432    }
433    for (func_index, function) in program.unconstrained_functions.iter().enumerate() {
434        writeln!(f, "unconstrained func {func_index}: {}", function.function_name)?;
435        let width = function.bytecode.len().to_string().len();
436        for (index, opcode) in function.bytecode.iter().enumerate() {
437            write!(f, "{index:>width$}: {opcode}")?;
438
439            if let ::brillig::Opcode::IndirectConst { value, .. } = opcode
440                && let Some(value) = value.try_to_u64()
441                && let Some(message) =
442                    error_types.and_then(|error_types| error_types.get(&ErrorSelector::new(value)))
443            {
444                write!(f, " // {message:?}")?;
445            }
446
447            writeln!(f)?;
448        }
449    }
450    Ok(())
451}
452
453impl<F: AcirField> std::fmt::Debug for Program<F> {
454    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
455        std::fmt::Display::fmt(self, f)
456    }
457}
458
459#[derive(Clone, Debug, PartialEq, Eq, Serialize, Deserialize, Default, Hash)]
460#[cfg_attr(feature = "arb", derive(proptest_derive::Arbitrary))]
461pub struct PublicInputs(pub BTreeSet<Witness>);
462
463impl PublicInputs {
464    /// Returns the witness index of each public input
465    pub fn indices(&self) -> Vec<u32> {
466        self.0.iter().map(|witness| witness.witness_index()).collect()
467    }
468
469    pub fn contains(&self, index: usize) -> bool {
470        self.0.contains(&Witness(index as u32))
471    }
472}
473
#[cfg(test)]
mod tests {
    use super::{Circuit, Compression};
    use crate::circuit::Program;
    use acir_field::{AcirField, FieldElement};
    use serde::{Deserialize, Serialize};

    #[test]
    fn serialization_roundtrip() {
        let src = "
        private parameters: []
        public parameters: [w2, w12]
        return values: [w4, w12]
        BLACKBOX::AND lhs: w1, rhs: w2, output: w3, bits: 4
        BLACKBOX::RANGE input: w1, bits: 8
        ";
        let circuit = Circuit::from_str(src).unwrap();
        let program = Program { functions: vec![circuit], unconstrained_functions: Vec::new() };

        fn read_write<F: Serialize + for<'a> Deserialize<'a> + AcirField>(
            program: Program<F>,
        ) -> (Program<F>, Program<F>) {
            let bytes = Program::serialize_program(&program);
            let got_program = Program::deserialize_program(&bytes).unwrap();
            (program, got_program)
        }

        let (circ, got_circ) = read_write(program);
        assert_eq!(circ, got_circ);
    }

    #[test]
    fn test_serialize() {
        let src = "
        private parameters: []
        public parameters: [w2]
        return values: [w2]
        ASSERT 0 = 8
        BLACKBOX::RANGE input: w1, bits: 8
        BLACKBOX::AND lhs: w1, rhs: w2, output: w3, bits: 4
        BLACKBOX::KECCAKF1600 inputs: [w1, w2, w3, w4, w5, w6, w7, w8, w9, w10, w11, w12, w13, w14, w15, w16, w17, w18, w19, w20, w21, w22, w23, w24, w25], outputs: [w26, w27, w28, w29, w30, w31, w32, w33, w34, w35, w36, w37, w38, w39, w40, w41, w42, w43, w44, w45, w46, w47, w48, w49, w50]
        ";
        let circuit = Circuit::from_str(src).unwrap();
        let program = Program { functions: vec![circuit], unconstrained_functions: Vec::new() };

        let json = serde_json::to_string_pretty(&program).unwrap();

        let deserialized = serde_json::from_str(&json).unwrap();
        assert_eq!(program, deserialized);
    }

    #[test]
    fn does_not_panic_on_invalid_circuit() {
        use std::io::Write;

        let bad_circuit = "I'm not an ACIR circuit".as_bytes();

        // We expect to load circuits as compressed artifacts so we compress the junk circuit.
        let mut zipped_bad_circuit = Vec::new();
        let mut encoder =
            flate2::write::GzEncoder::new(&mut zipped_bad_circuit, Compression::default());
        encoder.write_all(bad_circuit).unwrap();
        encoder.finish().unwrap();

        let deserialization_result: Result<Program<FieldElement>, _> =
            Program::deserialize_program(&zipped_bad_circuit);
        assert!(deserialization_result.is_err());
    }

    #[test]
    fn circuit_display_snapshot() {
        let src = "
        private parameters: []
        public parameters: [w2]
        return values: [w2]
        ASSERT 0 = 2*w1 + 8
        BLACKBOX::RANGE input: w1, bits: 8
        BLACKBOX::AND lhs: w1, rhs: w2, output: w3, bits: 4
        BLACKBOX::KECCAKF1600 inputs: [w1, w2, w3, w4, w5, w6, w7, w8, w9, w10, w11, w12, w13, w14, w15, w16, w17, w18, w19, w20, w21, w22, w23, w24, w25], outputs: [w26, w27, w28, w29, w30, w31, w32, w33, w34, w35, w36, w37, w38, w39, w40, w41, w42, w43, w44, w45, w46, w47, w48, w49, w50]
        ";
        let circuit = Circuit::from_str(src).unwrap();

        // All witnesses are expected to be formatted as `w{witness_index}`.
        insta::assert_snapshot!(
            circuit.to_string(),
            @r"
        private parameters: []
        public parameters: [w2]
        return values: [w2]
        ASSERT 0 = 2*w1 + 8
        BLACKBOX::RANGE input: w1, bits: 8
        BLACKBOX::AND lhs: w1, rhs: w2, output: w3, bits: 4
        BLACKBOX::KECCAKF1600 inputs: [w1, w2, w3, w4, w5, w6, w7, w8, w9, w10, w11, w12, w13, w14, w15, w16, w17, w18, w19, w20, w21, w22, w23, w24, w25], outputs: [w26, w27, w28, w29, w30, w31, w32, w33, w34, w35, w36, w37, w38, w39, w40, w41, w42, w43, w44, w45, w46, w47, w48, w49, w50]
        "
        );
    }

    /// Property based testing for serialization
    mod props {
        use acir_field::FieldElement;
        use proptest::prelude::*;
        use proptest::test_runner::{TestCaseResult, TestRunner};

        use crate::circuit::Program;
        use crate::native_types::{WitnessMap, WitnessStack};
        use crate::serialization::*;

        // It's not possible to set the maximum size of collections via `ProptestConfig`, only an env var,
        // because e.g. the `VecStrategy` uses `Config::default().max_default_size_range`. On top of that,
        // `Config::default()` reads a static `DEFAULT_CONFIG`, which gets the env vars only once at the
        // beginning, so we can't override this on a test-by-test basis, unless we use `fork`,
        // which is a feature that is currently disabled, because it doesn't work with Wasm.
        // We could add it as a `dev-dependency` just for this crate, but when I tried it just crashed.
        // For now using a const so it's obvious we can't set it to different values for different tests.
        const MAX_SIZE_RANGE: usize = 5;
        const SIZE_RANGE_KEY: &str = "PROPTEST_MAX_DEFAULT_SIZE_RANGE";

        // Define a wrapper around field so we can implement `Arbitrary`.
        // NB there are other methods like `arbitrary_field_elements` around the codebase,
        // but for `proptest_derive::Arbitrary` we need `F: AcirField + Arbitrary`.
        acir_field::field_wrapper!(TestField, FieldElement);

        impl Arbitrary for TestField {
            type Parameters = ();
            type Strategy = BoxedStrategy<Self>;

            fn arbitrary_with(_args: Self::Parameters) -> Self::Strategy {
                any::<u128>().prop_map(|v| Self(FieldElement::from(v))).boxed()
            }
        }

        /// Runs `cases` proptest cases of `f`, overriding the maximum size of collections
        /// created by `proptest`.
        ///
        /// If `PROPTEST_MAX_DEFAULT_SIZE_RANGE` is unset, it is set to [`MAX_SIZE_RANGE`] for the
        /// run and removed again afterwards. An existing value is left untouched.
        // `std::env::set_var`/`remove_var` are `unsafe` functions, hence the `allow`.
        #[allow(unsafe_code)]
        fn run_with_max_size_range<T, F>(cases: u32, f: F)
        where
            T: Arbitrary,
            F: Fn(T) -> TestCaseResult,
        {
            // The defaults are only read once. If they are already set, leave them be.
            let override_size_range = std::env::var_os(SIZE_RANGE_KEY).is_none();
            if override_size_range {
                // SAFETY: mutating the environment is only sound while no other thread reads or
                // writes it. NOTE(review): libtest runs tests on several threads, so this relies
                // on the other tests in this process not touching the environment concurrently.
                unsafe {
                    std::env::set_var(SIZE_RANGE_KEY, MAX_SIZE_RANGE.to_string());
                }
            }

            let mut runner = TestRunner::new(ProptestConfig { cases, ..Default::default() });
            let result = runner.run(&any::<T>(), f);

            // Undo only our own override, so the variable ends up exactly as we found it
            // (previously an unset variable was "restored" as an empty string).
            if override_size_range {
                // SAFETY: same reasoning as for `set_var` above.
                unsafe {
                    std::env::remove_var(SIZE_RANGE_KEY);
                }
            }

            result.unwrap();
        }

        #[test]
        fn prop_program_msgpack_roundtrip() {
            run_with_max_size_range(100, |(program, compact): (Program<TestField>, bool)| {
                let bz = msgpack_serialize(&program, compact)?;
                let de = msgpack_deserialize(&bz)?;
                prop_assert_eq!(program, de);
                Ok(())
            });
        }

        #[test]
        fn prop_program_roundtrip() {
            run_with_max_size_range(10, |program: Program<TestField>| {
                let bz = Program::serialize_program(&program);
                let de = Program::deserialize_program(&bz)?;
                prop_assert_eq!(program, de);
                Ok(())
            });
        }

        #[test]
        fn prop_witness_stack_msgpack_roundtrip() {
            run_with_max_size_range(10, |(witness, compact): (WitnessStack<TestField>, bool)| {
                let bz = msgpack_serialize(&witness, compact)?;
                let de = msgpack_deserialize(&bz)?;
                prop_assert_eq!(witness, de);
                Ok(())
            });
        }

        #[test]
        fn prop_witness_stack_roundtrip() {
            run_with_max_size_range(10, |witness: WitnessStack<TestField>| {
                let bz = witness.serialize()?;
                let de = WitnessStack::deserialize(bz.as_slice())?;
                prop_assert_eq!(witness, de);
                Ok(())
            });
        }

        #[test]
        fn prop_witness_map_msgpack_roundtrip() {
            run_with_max_size_range(10, |(witness, compact): (WitnessMap<TestField>, bool)| {
                let bz = msgpack_serialize(&witness, compact)?;
                let de = msgpack_deserialize(&bz)?;
                prop_assert_eq!(witness, de);
                Ok(())
            });
        }

        #[test]
        fn prop_witness_map_roundtrip() {
            run_with_max_size_range(10, |witness: WitnessMap<TestField>| {
                let bz = witness.serialize()?;
                let de = WitnessMap::deserialize(bz.as_slice())?;
                prop_assert_eq!(witness, de);
                Ok(())
            });
        }
    }
}