1#![cfg_attr(not(test), forbid(unsafe_code))] #![cfg_attr(not(test), warn(unused_crate_dependencies, unused_extern_crates))]
9
10#[doc = include_str!("../README.md")]
11pub mod circuit;
12pub mod native_types;
13mod parser;
14mod serialization;
15
16pub use acir_field;
17pub use acir_field::{AcirField, FieldElement};
18pub use brillig;
19pub use circuit::black_box_functions::BlackBoxFunc;
20pub use circuit::opcodes::InvalidInputBitSize;
21pub use parser::parse_opcodes;
22pub use serialization::Format as SerializationFormat;
23
24#[cfg(test)]
25mod reflection {
26 use std::{
37 collections::BTreeMap,
38 fs::File,
39 hash::BuildHasher,
40 io::Write,
41 path::{Path, PathBuf},
42 };
43
44 use acir_field::{AcirField, FieldElement};
45 use brillig::{
46 BinaryFieldOp, BinaryIntOp, BitSize, BlackBoxOp, HeapValueType, IntegerBitSize,
47 MemoryAddress, Opcode as BrilligOpcode, ValueOrArray,
48 };
49 use regex::Regex;
50 use serde::{Deserialize, Serialize};
51 use serde_generate::CustomCode;
52 use serde_reflection::{
53 ContainerFormat, Format, Named, Registry, Samples, Tracer, TracerConfig, VariantFormat,
54 };
55
56 use crate::{
57 circuit::{
58 AssertionPayload, Circuit, ExpressionOrMemory, Opcode, OpcodeLocation, Program,
59 brillig::{BrilligInputs, BrilligOutputs},
60 opcodes::{BlackBoxFuncCall, BlockType, FunctionInput, MemOp},
61 },
62 native_types::{Witness, WitnessMap, WitnessStack},
63 };
64
/// Variant of `Program` whose `unconstrained_functions` field has the unit type `()`.
///
/// It is traced next to `Program` in `serde_acir_cpp_codegen`. Unit-typed fields are skipped
/// when `MsgPackCodeGenerator::generate_struct` builds `msgpack_unpack`, so the generated C++ for
/// this type never decodes the unconstrained functions. In the ARRAY encoding the field still
/// occupies its position, so the fields around it keep their indices.
///
/// NOTE(review): assumed to mirror `Program`'s field order so that a serialized `Program` can be
/// read as this type (presumably to let the C++ side ignore Brillig bytecode) — confirm.
#[derive(Clone, PartialEq, Eq, Serialize, Deserialize, Default, Hash)]
struct ProgramWithoutBrillig<F: AcirField> {
    /// The constrained functions of the program.
    pub functions: Vec<Circuit<F>>,
    /// Placeholder standing in for the unconstrained functions; being `()`, the generated C++
    /// never reads it.
    pub unconstrained_functions: (),
}
84
/// Generates `./codegen/acir.cpp`, the C++ msgpack (de)serialization code for the ACIR program
/// format.
///
/// Fails if the output differs from the existing file, unless `NOIR_CODEGEN_OVERWRITE` is
/// enabled (see `serde_cpp_codegen`).
#[test]
fn serde_acir_cpp_codegen() {
    let config = TracerConfig::default().record_samples_for_structs(true);
    let mut tracer = Tracer::new(config);

    // Trace concrete `MemOp` values first; the collected `samples` are handed to
    // `trace_type::<Opcode<_>>` below (presumably because `MemOp` cannot be traced from its
    // type alone — confirm).
    let mut samples = Samples::new();
    tracer
        .trace_value(
            &mut samples,
            &MemOp::<FieldElement>::read_at_mem_index(Witness(0), Witness(0)),
        )
        .unwrap();
    tracer
        .trace_value(
            &mut samples,
            &MemOp::<FieldElement>::write_to_mem_index(Witness(0), Witness(0)),
        )
        .unwrap();

    // Each type only needs to be traced once.
    tracer.trace_simple_type::<BlockType>().unwrap();
    tracer.trace_simple_type::<Program<FieldElement>>().unwrap();
    tracer.trace_simple_type::<ProgramWithoutBrillig<FieldElement>>().unwrap();
    tracer.trace_simple_type::<Circuit<FieldElement>>().unwrap();
    tracer.trace_type::<Opcode<FieldElement>>(&samples).unwrap();
    tracer.trace_simple_type::<OpcodeLocation>().unwrap();
    tracer.trace_simple_type::<BinaryFieldOp>().unwrap();
    tracer.trace_simple_type::<FunctionInput<FieldElement>>().unwrap();
    tracer.trace_simple_type::<BlackBoxFuncCall<FieldElement>>().unwrap();
    tracer.trace_simple_type::<BrilligInputs<FieldElement>>().unwrap();
    tracer.trace_simple_type::<BrilligOutputs>().unwrap();
    tracer.trace_simple_type::<BrilligOpcode<FieldElement>>().unwrap();
    tracer.trace_simple_type::<BinaryIntOp>().unwrap();
    tracer.trace_simple_type::<BlackBoxOp>().unwrap();
    tracer.trace_simple_type::<ValueOrArray>().unwrap();
    tracer.trace_simple_type::<HeapValueType>().unwrap();
    tracer.trace_simple_type::<AssertionPayload<FieldElement>>().unwrap();
    tracer.trace_simple_type::<ExpressionOrMemory<FieldElement>>().unwrap();
    tracer.trace_simple_type::<BitSize>().unwrap();
    tracer.trace_simple_type::<IntegerBitSize>().unwrap();
    tracer.trace_simple_type::<MemoryAddress>().unwrap();

    serde_cpp_codegen(
        "Acir",
        PathBuf::from("./codegen/acir.cpp").as_path(),
        &tracer.registry().unwrap(),
        CustomCode::default(),
    );
}
137
/// Generates `./codegen/witness.cpp`, the C++ msgpack code for witness maps and stacks in the
/// `Witnesses` namespace.
#[test]
fn serde_witness_map_cpp_codegen() {
    let namespace = "Witnesses";

    let mut tracer = Tracer::new(TracerConfig::default());
    tracer.trace_simple_type::<Witness>().unwrap();
    tracer.trace_simple_type::<WitnessMap<FieldElement>>().unwrap();
    tracer.trace_simple_type::<WitnessStack<FieldElement>>().unwrap();
    let registry = tracer.registry().unwrap();

    // Give the C++ `Witness` an ordering (presumably so it can be used as a `std::map` key in
    // the generated `WitnessMap` — confirm).
    let witness_type_path = vec![namespace.to_string(), "Witness".to_string()];
    let less_than =
        "bool operator<(Witness const& rhs) const { return value < rhs.value; }".to_string();
    let mut extra_code = CustomCode::default();
    extra_code.insert(witness_type_path, less_than);

    let output_path = PathBuf::from("./codegen/witness.cpp");
    serde_cpp_codegen(namespace, &output_path, &registry, extra_code);
}
161
162 fn serde_cpp_codegen(namespace: &str, path: &Path, registry: &Registry, code: CustomCode) {
168 let old_hash = if path.is_file() {
169 let old_source = std::fs::read(path).expect("failed to read existing code");
170 let old_source = String::from_utf8(old_source).expect("old source not UTF-8");
171 Some(rustc_hash::FxBuildHasher.hash_one(&old_source))
172 } else {
173 None
174 };
175 let msgpack_code = MsgPackCodeGenerator::generate(
176 namespace,
177 registry,
178 code,
179 MsgPackCodeConfig::from_env(),
180 );
181
182 let mut source = Vec::new();
184 let config = serde_generate::CodeGeneratorConfig::new(namespace.to_string())
187 .with_encodings(vec![])
188 .with_custom_code(msgpack_code);
189 let generator = serde_generate::cpp::CodeGenerator::new(&config);
190 generator.output(&mut source, registry).expect("failed to generate C++ code");
191
192 let mut source = String::from_utf8(source).expect("not a UTF-8 string");
194 replace_throw(&mut source);
195 MsgPackCodeGenerator::add_preamble(&mut source);
196 MsgPackCodeGenerator::add_helpers(&mut source, namespace);
197 MsgPackCodeGenerator::replace_array_with_shared_ptr(&mut source);
198
199 if !should_overwrite()
200 && let Some(old_hash) = old_hash
201 {
202 let new_hash = rustc_hash::FxBuildHasher.hash_one(&source);
203 assert_eq!(new_hash, old_hash, "Serialization format has changed",);
204 }
205
206 write_to_file(source.as_bytes(), path);
207 }
208
/// Reads the boolean flag stored in the environment variable `name`.
///
/// Matching ignores surrounding whitespace and ASCII case:
/// - `1`, `true` and `yes` give `true`;
/// - `0`, `false` and `no` give `false`.
///
/// Returns `default` when the variable is unset, is not valid Unicode, is blank, or holds an
/// unrecognized value. In the last case a warning is printed to stderr, so a typo does not
/// silently fall back to the default.
fn env_flag(name: &str, default: bool) -> bool {
    let Ok(raw) = std::env::var(name) else {
        return default;
    };
    let normalized = raw.trim().to_ascii_lowercase();
    match normalized.as_str() {
        "1" | "true" | "yes" => true,
        "0" | "false" | "no" => false,
        "" => default,
        _ => {
            eprintln!(
                "warning: unrecognized value {raw:?} for {name}; expected 1/true/yes or \
                 0/false/no, using default {default}"
            );
            default
        }
    }
}
220
221 fn should_overwrite() -> bool {
224 env_flag("NOIR_CODEGEN_OVERWRITE", false)
225 }
226
/// Writes `bytes` to `path`, creating any missing parent directories and truncating an
/// existing file.
///
/// Returns the displayed form of `path`.
///
/// # Panics
/// Panics with a descriptive message if the parent directory, the file, or the write fails.
fn write_to_file(bytes: &[u8], path: &Path) -> String {
    let display = path.display();

    // `parent()` is `None` for a root/empty path and `Some("")` for a bare file name; in both
    // cases there is no directory to create. `create_dir_all` succeeds if it already exists.
    if let Some(parent_dir) = path.parent().filter(|dir| !dir.as_os_str().is_empty()) {
        std::fs::create_dir_all(parent_dir).unwrap_or_else(|why| {
            panic!("couldn't create directory {}: {why}", parent_dir.display())
        });
    }

    let mut file =
        File::create(path).unwrap_or_else(|why| panic!("couldn't create {display}: {why}"));
    file.write_all(bytes).unwrap_or_else(|why| panic!("couldn't write to {display}: {why}"));

    display.to_string()
}
245
/// Turns every `throw serde::deserialization_error` emitted by `serde_generate` into a call to
/// `throw_or_abort`, keeping whatever follows (e.g. the argument list) unchanged.
fn replace_throw(source: &mut String) {
    const GENERATED_THROW: &str = "throw serde::deserialization_error";
    const REPLACEMENT: &str = "throw_or_abort";
    let rewritten = source.split(GENERATED_THROW).collect::<Vec<_>>().join(REPLACEMENT);
    *source = rewritten;
}
254
/// Options controlling the msgpack code emitted by `MsgPackCodeGenerator`.
struct MsgPackCodeConfig {
    /// When `true`, struct `msgpack_pack` bodies pack every field (unit fields included) as an
    /// ARRAY in declaration order. Otherwise they pack a MAP of field name to value, skipping
    /// unit fields.
    pack_compact: bool,
    /// When `true`, no `msgpack_pack` methods are generated at all; `msgpack_unpack` methods
    /// are still generated.
    no_pack: bool,
}
261
262 impl MsgPackCodeConfig {
263 fn from_env() -> Self {
264 Self {
265 pack_compact: env_flag("NOIR_CODEGEN_PACK_COMPACT", true),
267 no_pack: env_flag("NOIR_CODEGEN_NO_PACK", true),
269 }
270 }
271 }
272
/// Builds C++ `msgpack_pack` / `msgpack_unpack` member functions for every type in a
/// `serde_reflection` registry. The result is collected as `CustomCode` for `serde_generate`
/// to splice into its output.
struct MsgPackCodeGenerator {
    /// Which methods to emit and which pack format to use.
    config: MsgPackCodeConfig,
    /// Current C++ type path prefix. An enum's name is pushed while its variants are generated,
    /// so variant code is keyed as `[namespace, Enum, Variant]`.
    namespace: Vec<String>,
    /// Custom code collected so far, keyed by the full type path.
    code: CustomCode,
}
280
// Generation of the msgpack C++ methods, plus textual post-processing of the C++ source produced
// by `serde_generate`. The emitted text is compared against a committed golden file by
// `serde_cpp_codegen`, so any change to the template strings below changes the output.
impl MsgPackCodeGenerator {
    /// Adds `#include "barretenberg/serialize/msgpack_impl.hpp"` on a new line right after the
    /// `#include "serde.hpp"` directive in `source`.
    ///
    /// # Panics
    /// Panics if `source` does not contain `#include "serde.hpp"`.
    pub(crate) fn add_preamble(source: &mut String) {
        let inc = r#"#include "serde.hpp""#;
        let pos = source.find(inc).expect("serde.hpp missing");
        // Insert after the existing include, not before it.
        source.insert_str(
            pos + inc.len(),
            "\n#include \"barretenberg/serialize/msgpack_impl.hpp\"",
        );
    }

    /// Inserts a `namespace {namespace} { struct Helpers { ... }; }` block immediately before
    /// the first occurrence of `namespace {namespace}` in `source`.
    ///
    /// `Helpers` holds the C++ functions called by the generated struct `msgpack_unpack` bodies:
    /// - `make_kvmap` indexes a msgpack MAP by its STR keys;
    /// - `conv_fld_from_kvmap` converts a field looked up by name, failing when it is missing
    ///   unless it is optional;
    /// - `conv_fld_from_array` converts a field by position, failing when out of bounds.
    ///
    /// # Panics
    /// Panics if `source` does not contain `namespace {namespace}`.
    pub(crate) fn add_helpers(source: &mut String, namespace: &str) {
        let helpers = r#"
    struct Helpers {
        static std::map<std::string, msgpack::object const*> make_kvmap(
            msgpack::object const& o,
            std::string const& name
        ) {
            if (o.type != msgpack::type::MAP) {
                std::cerr << o << std::endl;
                throw_or_abort("expected MAP for " + name);
            }
            std::map<std::string, msgpack::object const*> kvmap;
            for (uint32_t i = 0; i < o.via.map.size; ++i) {
                if (o.via.map.ptr[i].key.type != msgpack::type::STR) {
                    std::cerr << o << std::endl;
                    throw_or_abort("expected STR for keys of " + name);
                }
                kvmap.emplace(
                    std::string(
                        o.via.map.ptr[i].key.via.str.ptr,
                        o.via.map.ptr[i].key.via.str.size),
                    &o.via.map.ptr[i].val);
            }
            return kvmap;
        }

        template<typename T>
        static void conv_fld_from_kvmap(
            std::map<std::string, msgpack::object const*> const& kvmap,
            std::string const& struct_name,
            std::string const& field_name,
            T& field,
            bool is_optional
        ) {
            auto it = kvmap.find(field_name);
            if (it != kvmap.end()) {
                try {
                    it->second->convert(field);
                } catch (const msgpack::type_error&) {
                    std::cerr << *it->second << std::endl;
                    throw_or_abort("error converting into field " + struct_name + "::" + field_name);
                }
            } else if (!is_optional) {
                throw_or_abort("missing field: " + struct_name + "::" + field_name);
            }
        }

        template<typename T>
        static void conv_fld_from_array(
            msgpack::object_array const& array,
            std::string const& struct_name,
            std::string const& field_name,
            T& field,
            uint32_t index
        ) {
            if (index >= array.size) {
                throw_or_abort("index out of bounds: " + struct_name + "::" + field_name + " at " + std::to_string(index));
            }
            auto element = array.ptr[index];
            try {
                element.convert(field);
            } catch (const msgpack::type_error&) {
                std::cerr << element << std::endl;
                throw_or_abort("error converting into field " + struct_name + "::" + field_name);
            }
        }
    };
"#;
        let pos = source.find(&format!("namespace {namespace}")).expect("namespace");
        source.insert_str(pos, &format!("namespace {namespace} {{{helpers}}}\n\n"));
    }

    /// Rewrites every `std::array<T, N>` in `source` into `std::shared_ptr<std::array<T, N>>`.
    /// The element type is trimmed and written back as `T, N`.
    ///
    /// Only element types without `,`, `<` or `>` match, so arrays of template types are left
    /// untouched. NOTE(review): the motivation is not visible here (presumably keeping large
    /// fixed-size arrays out of the generated types' inline storage) — confirm.
    fn replace_array_with_shared_ptr(source: &mut String) {
        let re = Regex::new(r#"std::array<\s*([^,<>]+?)\s*,\s*([0-9]+)\s*>"#)
            .expect("failed to create regex");

        let fixed =
            re.replace_all(source, "std::shared_ptr<std::array<${1}, ${2}>>").into_owned();

        *source = fixed;
    }

    /// Generates msgpack methods for every container in `registry` and appends them to `code`.
    /// Each type's code is keyed as `[namespace, TypeName]`. Returns the combined custom code.
    pub(crate) fn generate(
        namespace: &str,
        registry: &Registry,
        code: CustomCode,
        config: MsgPackCodeConfig,
    ) -> CustomCode {
        let mut g = Self { namespace: vec![namespace.to_string()], code, config };
        for (name, container) in registry {
            g.generate_container(name, container);
        }
        g.code
    }

    /// Appends `code` plus a trailing newline to the custom code of type `name` under the
    /// current namespace path.
    ///
    /// When the type already has code and the new snippet spans several lines, an extra `\n` is
    /// inserted first to separate multi-line methods.
    fn add_code(&mut self, name: &str, code: &str) {
        let mut ns = self.namespace.clone();
        ns.push(name.to_string());
        let c = self.code.entry(ns).or_default();
        if !c.is_empty() && code.contains('\n') {
            c.push('\n');
        }
        c.push_str(code);
        c.push('\n');
    }

    /// Dispatches code generation on the kind of container.
    ///
    /// # Panics
    /// Tuple structs are not supported and panic in `generate_tuple`.
    fn generate_container(&mut self, name: &str, container: &ContainerFormat) {
        use serde_reflection::ContainerFormat::*;
        match container {
            UnitStruct => {
                self.generate_unit_struct(name);
            }
            NewTypeStruct(_format) => {
                self.generate_newtype(name);
            }
            TupleStruct(formats) => {
                self.generate_tuple(name, formats);
            }
            Struct(fields) => {
                self.generate_struct(name, fields);
            }
            Enum(variants) => {
                self.generate_enum(name, variants);
            }
        }
    }

    /// Unit structs carry no data, so both methods get empty bodies (rendered as `{}` by
    /// `make_fn`).
    fn generate_unit_struct(&mut self, name: &str) {
        self.msgpack_pack(name, "");
        self.msgpack_unpack(name, "");
    }

    /// Generates methods for a struct with named `fields`.
    ///
    /// Packing uses an ARRAY of all fields when `config.pack_compact` is set, otherwise a MAP
    /// of the non-unit fields. Unpacking accepts either form:
    /// - a MAP is read by field name, and missing `Option` fields are allowed;
    /// - an ARRAY is read by the field's declaration position.
    ///
    /// Unit fields are never unpacked, but in the ARRAY form they still occupy their position.
    fn generate_struct(&mut self, name: &str, fields: &[Named<Format>]) {
        /// Whether the field has the unit type, i.e. carries no data.
        fn is_unit(field: &Named<Format>) -> bool {
            matches!(field.value, Format::Unit)
        }

        let non_unit_field_count = fields.iter().filter(|f| !is_unit(f)).count();

        self.msgpack_pack(name, &{
            if self.config.pack_compact {
                // ARRAY of every field (unit fields included), in declaration order.
                let mut body = format!(
                    "
    packer.pack_array({});",
                    fields.len()
                );
                for field in fields {
                    let field_name = &field.name;
                    body.push_str(&format!(
                        r#"
    packer.pack({field_name});"#
                    ));
                }
                body
            } else {
                // MAP of `"field_name" -> value` for the non-unit fields only.
                let mut body = format!(
                    "
    packer.pack_map({non_unit_field_count});",
                );
                for field in fields {
                    if is_unit(field) {
                        continue;
                    }
                    let field_name = &field.name;
                    body.push_str(&format!(
                        r#"
    packer.pack(std::make_pair("{field_name}", {field_name}));"#
                    ));
                }
                body
            }
        });

        self.msgpack_unpack(name, &{
            // MAP branch: look each non-unit field up by name.
            let mut body = format!(
                r#"
    std::string name = "{name}";
    if (o.type == msgpack::type::MAP) {{
        auto kvmap = Helpers::make_kvmap(o, name);"#
            );
            for field in fields {
                if is_unit(field) {
                    continue;
                }
                let field_name = &field.name;
                // Only `Option` fields may be absent from the MAP.
                let is_optional = matches!(field.value, Format::Option(_));
                body.push_str(&format!(
                    r#"
        Helpers::conv_fld_from_kvmap(kvmap, name, "{field_name}", {field_name}, {is_optional});"#
                ));
            }
            // ARRAY branch: read fields by position.
            body.push_str(
                "
    } else if (o.type == msgpack::type::ARRAY) {
        auto array = o.via.array; ",
            );
            // `index` counts unit fields too, matching the compact pack format above.
            for (index, field) in fields.iter().enumerate() {
                if is_unit(field) {
                    continue;
                }
                let field_name = &field.name;
                body.push_str(&format!(
                    r#"
        Helpers::conv_fld_from_array(array, name, "{field_name}", {field_name}, {index});"#
                ));
            }

            body.push_str(
                r#"
    } else {
        throw_or_abort("expected MAP or ARRAY for " + name);
    }"#,
            );
            body
        });
    }

    /// Generates methods for a newtype: the wrapped `value` is packed and unpacked directly.
    fn generate_newtype(&mut self, name: &str) {
        self.msgpack_pack(name, "packer.pack(value);");
        self.msgpack_unpack(
            name,
            &format!(
                r#"
    try {{
        o.convert(value);
    }} catch (const msgpack::type_error&) {{
        std::cerr << o << std::endl;
        throw_or_abort("error converting into newtype '{name}'");
    }}
"#
            ),
        );
    }

    /// Tuple structs/variants are not supported.
    ///
    /// # Panics
    /// Always panics via `unimplemented!`.
    fn generate_tuple(&self, _name: &str, _formats: &[Format]) {
        unimplemented!("Until we have a tuple enum in our schema we don't need this.");
    }

    /// Generates methods for an enum and all of its variants.
    ///
    /// Pack format:
    /// - a unit variant is packed as its bare tag string;
    /// - any other variant is packed as a one-entry MAP `{tag: payload}`.
    ///
    /// Unpacking accepts a STR or a one-entry MAP and dispatches on the tag string.
    fn generate_enum(&mut self, name: &str, variants: &BTreeMap<u32, Named<VariantFormat>>) {
        // Variant code is keyed under `[.., name, variant]`, i.e. nested inside the enum's type.
        self.namespace.push(name.to_string());
        for variant in variants.values() {
            self.generate_variant(&variant.name, &variant.value);
        }
        self.namespace.pop();

        self.msgpack_pack(name, &{
            // One `case` per variant index, recording its tag and whether it is a unit variant.
            let cases = variants
                .iter()
                .map(|(i, v)| {
                    format!(
                        r#"
        case {i}:
            tag = "{}";
            is_unit = {};
            break;"#,
                        v.name,
                        matches!(v.value, VariantFormat::Unit)
                    )
                })
                .collect::<Vec<_>>()
                .join("");

            format!(
                r#"
    std::string tag;
    bool is_unit;
    switch (value.index()) {{
        {cases}
        default:
            throw_or_abort("unknown enum '{name}' variant index: " + std::to_string(value.index()));
    }}
    if (is_unit) {{
        packer.pack(tag);
    }} else {{
        std::visit([&packer, tag](const auto& arg) {{
            packer.pack_map(1);
            packer.pack(tag);
            packer.pack(arg);
        }}, value);
    }}"#
            )
        });

        self.msgpack_unpack(name, &{
            // Validate the shape (STR, or MAP with exactly one entry) and extract the tag.
            let mut body = format!(
                r#"

    if (o.type != msgpack::type::object_type::MAP && o.type != msgpack::type::object_type::STR) {{
        std::cerr << o << std::endl;
        throw_or_abort("expected MAP or STR for enum '{name}'; got type " + std::to_string(o.type));
    }}
    if (o.type == msgpack::type::object_type::MAP && o.via.map.size != 1) {{
        throw_or_abort("expected 1 entry for enum '{name}'; got " + std::to_string(o.via.map.size));
    }}
    std::string tag;
    try {{
        if (o.type == msgpack::type::object_type::MAP) {{
            o.via.map.ptr[0].key.convert(tag);
        }} else {{
            o.convert(tag);
        }}
    }} catch(const msgpack::type_error&) {{
        std::cerr << o << std::endl;
        throw_or_abort("error converting tag to string for enum '{name}'");
    }}"#
            );
            for (i, v) in variants {
                let variant = &v.name;
                // NOTE(review): `else ` is omitted only when the variant index is 0, which
                // assumes the smallest index is 0. Otherwise the generated if/else chain would
                // start with a dangling `else`.
                body.push_str(&format!(
                    r#"
    {}if (tag == "{variant}") {{
        {variant} v;"#,
                    if *i == 0 { "" } else { "else " }
                ));

                if !matches!(v.value, VariantFormat::Unit) {
                    // NOTE(review): the generated code reads `o.via.map` without re-checking
                    // `o.type == MAP`. A STR tag naming a non-unit variant passes the checks
                    // above — confirm this input cannot occur.
                    body.push_str(&format!(
                        r#"
        try {{
            o.via.map.ptr[0].val.convert(v);
        }} catch (const msgpack::type_error&) {{
            std::cerr << o << std::endl;
            throw_or_abort("error converting into enum variant '{name}::{variant}'");
        }}
"#
                    ));
                }
                body.push_str(
                    r#"
        value = v;
    }"#,
                );
            }
            body.push_str(&format!(
                r#"
    else {{
        std::cerr << o << std::endl;
        throw_or_abort("unknown '{name}' enum variant: " + tag);
    }}"#
            ));
            body
        });
    }

    /// Generates methods for a single enum variant. The variant is a C++ type nested under the
    /// enum: the enum's unpack declares `{variant} v;` and assigns it to `value`.
    fn generate_variant(&mut self, name: &str, variant: &VariantFormat) {
        match variant {
            VariantFormat::Variable(_) => {
                unreachable!("internal construct")
            }
            VariantFormat::Unit => self.generate_unit_struct(name),
            VariantFormat::NewType(_format) => self.generate_newtype(name),
            VariantFormat::Tuple(formats) => self.generate_tuple(name, formats),
            VariantFormat::Struct(fields) => self.generate_struct(name, fields),
        }
    }

    /// Emits `MSGPACK_FIELDS(a, b, ...);` for type `name`.
    ///
    /// It is not called anywhere in this impl; the hand-written bodies in `generate_struct` are
    /// used instead. It is presumably kept as a reference alternative, hence `allow(dead_code)`.
    #[allow(dead_code)]
    fn msgpack_fields(&mut self, name: &str, fields: impl Iterator<Item = String>) {
        let fields = fields.collect::<Vec<_>>().join(", ");
        let code = format!("MSGPACK_FIELDS({fields});");
        self.add_code(name, &code);
    }

    /// Adds `void msgpack_pack(auto& packer) const` with `body` to type `name`. This is a no-op
    /// when `config.no_pack` is set.
    fn msgpack_pack(&mut self, name: &str, body: &str) {
        if self.config.no_pack {
            return;
        }
        let code = Self::make_fn("void msgpack_pack(auto& packer) const", body);
        self.add_code(name, &code);
    }

    /// Adds `void msgpack_unpack(msgpack::object const& o)` with `body` to type `name`. This
    /// method is always generated, regardless of `no_pack`.
    fn msgpack_unpack(&mut self, name: &str, body: &str) {
        let code = Self::make_fn("void msgpack_unpack(msgpack::object const& o)", body);
        self.add_code(name, &code);
    }

    /// Formats a C++ member function from `header` and `body`, after trimming trailing
    /// whitespace from `body`:
    /// - an empty body becomes `header {}`;
    /// - a single-line body becomes `header { body }`;
    /// - a body starting with a newline is wrapped as `header {body\n}`;
    /// - any other multi-line body becomes `header {\nbody\n}`.
    fn make_fn(header: &str, body: &str) -> String {
        let body = body.trim_end();
        if body.is_empty() {
            format!("{header} {{}}")
        } else if !body.contains('\n') {
            format!("{header} {{ {body} }}")
        } else if body.starts_with('\n') {
            format!("{header} {{{body}\n}}")
        } else {
            format!("{header} {{\n{body}\n}}")
        }
    }
}
784}