1#![cfg_attr(not(test), forbid(unsafe_code))] #![cfg_attr(not(test), warn(unused_crate_dependencies, unused_extern_crates))]
9
10#[doc = include_str!("../README.md")]
11pub mod circuit;
12pub mod native_types;
13mod parser;
14mod serialization;
15
16pub use acir_field;
17pub use acir_field::{AcirField, FieldElement};
18pub use brillig;
19pub use circuit::black_box_functions::BlackBoxFunc;
20pub use circuit::opcodes::InvalidInputBitSize;
21pub use parser::parse_opcodes;
22pub use serialization::Format as SerializationFormat;
23
24#[cfg(test)]
25mod reflection {
26 use std::{
37 collections::BTreeMap,
38 fs::File,
39 hash::BuildHasher,
40 io::Write,
41 path::{Path, PathBuf},
42 };
43
44 use acir_field::{AcirField, FieldElement};
45 use brillig::{
46 BinaryFieldOp, BinaryIntOp, BitSize, BlackBoxOp, HeapValueType, IntegerBitSize,
47 MemoryAddress, Opcode as BrilligOpcode, ValueOrArray,
48 };
49 use regex::Regex;
50 use serde::{Deserialize, Serialize};
51 use serde_generate::CustomCode;
52 use serde_reflection::{
53 ContainerFormat, Format, Named, Registry, Tracer, TracerConfig, VariantFormat,
54 };
55
56 use crate::{
57 circuit::{
58 AssertionPayload, Circuit, ExpressionOrMemory, Opcode, OpcodeLocation, Program,
59 brillig::{BrilligInputs, BrilligOutputs},
60 opcodes::{BlackBoxFuncCall, BlockType, FunctionInput},
61 },
62 native_types::{Witness, WitnessMap, WitnessStack},
63 };
64
 /// A program-shaped container holding the ACIR `functions` and a unit-typed
 /// `unconstrained_functions` field.
 ///
 /// It is traced in `serde_acir_cpp_codegen` next to `Program`, so the generated
 /// C++ also contains a container with exactly these two fields.
 /// NOTE(review): presumably this lets consumers decode a program while ignoring
 /// its Brillig (unconstrained) functions; confirm against the C++ side.
 #[derive(Clone, PartialEq, Eq, Serialize, Deserialize, Default, Hash)]
 struct ProgramWithoutBrillig<F: AcirField> {
 /// The ACIR circuits of the program.
 pub functions: Vec<Circuit<F>>,
 /// Placeholder of unit type, so no unconstrained function data is carried.
 pub unconstrained_functions: (),
 }
84
85 #[test]
86 fn serde_acir_cpp_codegen() {
87 let mut tracer = Tracer::new(TracerConfig::default());
88 tracer.trace_simple_type::<BlockType>().unwrap();
89 tracer.trace_simple_type::<Program<FieldElement>>().unwrap();
90 tracer.trace_simple_type::<ProgramWithoutBrillig<FieldElement>>().unwrap();
91 tracer.trace_simple_type::<Circuit<FieldElement>>().unwrap();
92 tracer.trace_simple_type::<Opcode<FieldElement>>().unwrap();
93 tracer.trace_simple_type::<OpcodeLocation>().unwrap();
94 tracer.trace_simple_type::<BinaryFieldOp>().unwrap();
95 tracer.trace_simple_type::<FunctionInput<FieldElement>>().unwrap();
96 tracer.trace_simple_type::<FunctionInput<FieldElement>>().unwrap();
97 tracer.trace_simple_type::<BlackBoxFuncCall<FieldElement>>().unwrap();
98 tracer.trace_simple_type::<BrilligInputs<FieldElement>>().unwrap();
99 tracer.trace_simple_type::<BrilligOutputs>().unwrap();
100 tracer.trace_simple_type::<BrilligOpcode<FieldElement>>().unwrap();
101 tracer.trace_simple_type::<BinaryIntOp>().unwrap();
102 tracer.trace_simple_type::<BlackBoxOp>().unwrap();
103 tracer.trace_simple_type::<ValueOrArray>().unwrap();
104 tracer.trace_simple_type::<HeapValueType>().unwrap();
105 tracer.trace_simple_type::<AssertionPayload<FieldElement>>().unwrap();
106 tracer.trace_simple_type::<ExpressionOrMemory<FieldElement>>().unwrap();
107 tracer.trace_simple_type::<BitSize>().unwrap();
108 tracer.trace_simple_type::<IntegerBitSize>().unwrap();
109 tracer.trace_simple_type::<MemoryAddress>().unwrap();
110
111 serde_cpp_codegen(
112 "Acir",
113 PathBuf::from("./codegen/acir.cpp").as_path(),
114 &tracer.registry().unwrap(),
115 CustomCode::default(),
116 );
117 }
118
119 #[test]
120 fn serde_witness_map_cpp_codegen() {
121 let mut tracer = Tracer::new(TracerConfig::default());
122 tracer.trace_simple_type::<Witness>().unwrap();
123 tracer.trace_simple_type::<WitnessMap<FieldElement>>().unwrap();
124 tracer.trace_simple_type::<WitnessStack<FieldElement>>().unwrap();
125
126 let namespace = "Witnesses";
127 let mut code = CustomCode::default();
128 code.insert(
131 vec![namespace.to_string(), "Witness".to_string()],
132 "bool operator<(Witness const& rhs) const { return value < rhs.value; }".to_string(),
133 );
134
135 serde_cpp_codegen(
136 namespace,
137 PathBuf::from("./codegen/witness.cpp").as_path(),
138 &tracer.registry().unwrap(),
139 code,
140 );
141 }
142
143 fn serde_cpp_codegen(namespace: &str, path: &Path, registry: &Registry, code: CustomCode) {
149 let old_hash = if path.is_file() {
150 let old_source = std::fs::read(path).expect("failed to read existing code");
151 let old_source = String::from_utf8(old_source).expect("old source not UTF-8");
152 Some(rustc_hash::FxBuildHasher.hash_one(&old_source))
153 } else {
154 None
155 };
156 let msgpack_code = MsgPackCodeGenerator::generate(
157 namespace,
158 registry,
159 code,
160 MsgPackCodeConfig::from_env(),
161 );
162
163 let mut source = Vec::new();
165 let config = serde_generate::CodeGeneratorConfig::new(namespace.to_string())
168 .with_encodings(vec![])
169 .with_custom_code(msgpack_code);
170 let generator = serde_generate::cpp::CodeGenerator::new(&config);
171 generator.output(&mut source, registry).expect("failed to generate C++ code");
172
173 let mut source = String::from_utf8(source).expect("not a UTF-8 string");
175 replace_throw(&mut source);
176 MsgPackCodeGenerator::add_preamble(&mut source);
177 MsgPackCodeGenerator::add_helpers(&mut source, namespace);
178 MsgPackCodeGenerator::replace_array_with_shared_ptr(&mut source);
179
180 if !should_overwrite()
181 && let Some(old_hash) = old_hash
182 {
183 let new_hash = rustc_hash::FxBuildHasher.hash_one(&source);
184 assert_eq!(new_hash, old_hash, "Serialization format has changed",);
185 }
186
187 write_to_file(source.as_bytes(), path);
188 }
189
190 fn env_flag(name: &str, default: bool) -> bool {
192 let Ok(s) = std::env::var(name) else {
193 return default;
194 };
195 match s.as_str() {
196 "1" | "true" | "yes" => true,
197 "0" | "false" | "no" => false,
198 _ => default,
199 }
200 }
201
202 fn should_overwrite() -> bool {
205 env_flag("NOIR_CODEGEN_OVERWRITE", false)
206 }
207
208 fn write_to_file(bytes: &[u8], path: &Path) -> String {
209 let display = path.display();
210
211 let parent_dir = path.parent().unwrap();
212 if !parent_dir.is_dir() {
213 std::fs::create_dir_all(parent_dir).unwrap();
214 }
215
216 let mut file = match File::create(path) {
217 Err(why) => panic!("couldn't create {display}: {why}"),
218 Ok(file) => file,
219 };
220
221 match file.write_all(bytes) {
222 Err(why) => panic!("couldn't write to {display}: {why}"),
223 Ok(_) => display.to_string(),
224 }
225 }
226
227 fn replace_throw(source: &mut String) {
233 *source = source.replace("throw serde::deserialization_error", "throw_or_abort");
234 }
235
 /// Options controlling which msgpack methods `MsgPackCodeGenerator` emits and how.
 struct MsgPackCodeConfig {
 /// When `true`, structs are packed as an ARRAY of all their fields in
 /// declaration order. Otherwise they are packed as a MAP from field name to
 /// value, omitting unit-typed fields.
 pack_compact: bool,
 /// When `true`, no `msgpack_pack` methods are generated at all.
 no_pack: bool,
 }
242
243 impl MsgPackCodeConfig {
244 fn from_env() -> Self {
245 Self {
246 pack_compact: env_flag("NOIR_CODEGEN_PACK_COMPACT", true),
248 no_pack: env_flag("NOIR_CODEGEN_NO_PACK", true),
250 }
251 }
252 }
253
 /// Accumulates custom C++ `msgpack_pack` / `msgpack_unpack` methods for the
 /// containers of a serde-reflection `Registry`.
 ///
 /// The collected code is handed to serde-generate as custom code.
 struct MsgPackCodeGenerator {
 /// Generation options (see `MsgPackCodeConfig::from_env`).
 config: MsgPackCodeConfig,
 /// Qualified-name prefix for the container being generated. Enum variants
 /// are nested under their enum's name while they are generated.
 namespace: Vec<String>,
 /// Generated snippets keyed by qualified name (namespace path + type name).
 code: CustomCode,
 }
261
262 impl MsgPackCodeGenerator {
263 pub(crate) fn add_preamble(source: &mut String) {
265 let inc = r#"#include "serde.hpp""#;
266 let pos = source.find(inc).expect("serde.hpp missing");
267 source.insert_str(
268 pos + inc.len(),
269 "\n#include \"barretenberg/serialize/msgpack_impl.hpp\"",
270 );
271 }
272
 /// Inserts a C++ `Helpers` struct, wrapped in its own `namespace {namespace} { ... }`
 /// block, right before the first occurrence of `namespace {namespace}` in `source`.
 ///
 /// The helpers are called by the generated `msgpack_unpack` bodies:
 /// - `make_kvmap` turns a MAP object with STR keys into a name-to-value lookup.
 /// - `conv_fld_from_kvmap` converts a named field and reports it as missing
 ///   unless it is optional.
 /// - `conv_fld_from_array` converts a positional field after bounds-checking the index.
 ///
 /// # Panics
 /// Panics with "namespace" if `source` does not contain `namespace {namespace}`.
 pub(crate) fn add_helpers(source: &mut String, namespace: &str) {
 // The literal below is emitted verbatim into the generated C++, so editing it
 // changes the generated file (and the hash checked by `serde_cpp_codegen`).
 let helpers = r#"
 struct Helpers {
 static std::map<std::string, msgpack::object const*> make_kvmap(
 msgpack::object const& o,
 std::string const& name
 ) {
 if (o.type != msgpack::type::MAP) {
 std::cerr << o << std::endl;
 throw_or_abort("expected MAP for " + name);
 }
 std::map<std::string, msgpack::object const*> kvmap;
 for (uint32_t i = 0; i < o.via.map.size; ++i) {
 if (o.via.map.ptr[i].key.type != msgpack::type::STR) {
 std::cerr << o << std::endl;
 throw_or_abort("expected STR for keys of " + name);
 }
 kvmap.emplace(
 std::string(
 o.via.map.ptr[i].key.via.str.ptr,
 o.via.map.ptr[i].key.via.str.size),
 &o.via.map.ptr[i].val);
 }
 return kvmap;
 }

 template<typename T>
 static void conv_fld_from_kvmap(
 std::map<std::string, msgpack::object const*> const& kvmap,
 std::string const& struct_name,
 std::string const& field_name,
 T& field,
 bool is_optional
 ) {
 auto it = kvmap.find(field_name);
 if (it != kvmap.end()) {
 try {
 it->second->convert(field);
 } catch (const msgpack::type_error&) {
 std::cerr << *it->second << std::endl;
 throw_or_abort("error converting into field " + struct_name + "::" + field_name);
 }
 } else if (!is_optional) {
 throw_or_abort("missing field: " + struct_name + "::" + field_name);
 }
 }

 template<typename T>
 static void conv_fld_from_array(
 msgpack::object_array const& array,
 std::string const& struct_name,
 std::string const& field_name,
 T& field,
 uint32_t index
 ) {
 if (index >= array.size) {
 throw_or_abort("index out of bounds: " + struct_name + "::" + field_name + " at " + std::to_string(index));
 }
 auto element = array.ptr[index];
 try {
 element.convert(field);
 } catch (const msgpack::type_error&) {
 std::cerr << element << std::endl;
 throw_or_abort("error converting into field " + struct_name + "::" + field_name);
 }
 }
 };
 "#;
 // Place the helpers ahead of the first block of the namespace so later code in
 // that namespace can refer to `Helpers::...`.
 let pos = source.find(&format!("namespace {namespace}")).expect("namespace");
 source.insert_str(pos, &format!("namespace {namespace} {{{helpers}}}\n\n"));
 }
349
350 fn replace_array_with_shared_ptr(source: &mut String) {
352 let re = Regex::new(r#"std::array<\s*([^,<>]+?)\s*,\s*([0-9]+)\s*>"#)
354 .expect("failed to create regex");
355
356 let fixed =
357 re.replace_all(source, "std::shared_ptr<std::array<${1}, ${2}>>").into_owned();
358
359 *source = fixed;
360 }
361
362 pub(crate) fn generate(
364 namespace: &str,
365 registry: &Registry,
366 code: CustomCode,
367 config: MsgPackCodeConfig,
368 ) -> CustomCode {
369 let mut g = Self { namespace: vec![namespace.to_string()], code, config };
370 for (name, container) in registry {
371 g.generate_container(name, container);
372 }
373 g.code
374 }
375
376 fn add_code(&mut self, name: &str, code: &str) {
378 let mut ns = self.namespace.clone();
379 ns.push(name.to_string());
380 let c = self.code.entry(ns).or_default();
381 if !c.is_empty() && code.contains('\n') {
382 c.push('\n');
383 }
384 c.push_str(code);
385 c.push('\n');
386 }
387
388 fn generate_container(&mut self, name: &str, container: &ContainerFormat) {
389 use serde_reflection::ContainerFormat::*;
390 match container {
391 UnitStruct => {
392 self.generate_unit_struct(name);
393 }
394 NewTypeStruct(_format) => {
395 self.generate_newtype(name);
396 }
397 TupleStruct(formats) => {
398 self.generate_tuple(name, formats);
399 }
400 Struct(fields) => {
401 self.generate_struct(name, fields);
402 }
403 Enum(variants) => {
404 self.generate_enum(name, variants);
405 }
406 }
407 }
408
409 fn generate_unit_struct(&mut self, name: &str) {
411 self.msgpack_pack(name, "");
415 self.msgpack_unpack(name, "");
416 }
417
 /// Emits `msgpack_pack` / `msgpack_unpack` bodies for a struct (or struct-like
 /// enum variant) called `name` with the given named `fields`.
 ///
 /// Packing depends on `config.pack_compact`. In the compact case it is an ARRAY
 /// of every field in declaration order. Otherwise it is a MAP from field name to
 /// value that omits unit fields.
 ///
 /// Unpacking accepts either encoding. MAP entries are looked up by field name;
 /// fields whose format is `Option` may be absent. ARRAY elements are read by the
 /// field's position among all fields. Unit fields are never unpacked.
 fn generate_struct(&mut self, name: &str, fields: &[Named<Format>]) {
 // Unit-typed fields carry no data; they are skipped in the MAP layout and
 // when unpacking.
 fn is_unit(field: &Named<Format>) -> bool {
 matches!(field.value, Format::Unit)
 }

 let non_unit_field_count = fields.iter().filter(|f| !is_unit(f)).count();

 self.msgpack_pack(name, &{
 if self.config.pack_compact {
 // Compact layout: positional ARRAY with every field, unit ones included.
 let mut body = format!(
 "
 packer.pack_array({});",
 fields.len()
 );
 for field in fields {
 let field_name = &field.name;
 body.push_str(&format!(
 r#"
 packer.pack({field_name});"#
 ));
 }
 body
 } else {
 // Named layout: MAP of "field_name" -> value, unit fields omitted.
 let mut body = format!(
 "
 packer.pack_map({non_unit_field_count});",
 );
 for field in fields {
 if is_unit(field) {
 continue;
 }
 let field_name = &field.name;
 body.push_str(&format!(
 r#"
 packer.pack(std::make_pair("{field_name}", {field_name}));"#
 ));
 }
 body
 }
 });

 self.msgpack_unpack(name, &{
 // The generated C++ keeps the struct name in a local `name` for error messages.
 let mut body = format!(
 r#"
 std::string name = "{name}";
 if (o.type == msgpack::type::MAP) {{
 auto kvmap = Helpers::make_kvmap(o, name);"#
 );
 for field in fields {
 if is_unit(field) {
 continue;
 }
 let field_name = &field.name;
 // Only `Option` fields are allowed to be missing from the MAP.
 let is_optional = matches!(field.value, Format::Option(_));
 body.push_str(&format!(
 r#"
 Helpers::conv_fld_from_kvmap(kvmap, name, "{field_name}", {field_name}, {is_optional});"#
 ));
 }
 body.push_str(
 "
 } else if (o.type == msgpack::type::ARRAY) {
 auto array = o.via.array; ",
 );
 // `index` counts unit fields too, matching the compact ARRAY layout above.
 for (index, field) in fields.iter().enumerate() {
 if is_unit(field) {
 continue;
 }
 let field_name = &field.name;
 body.push_str(&format!(
 r#"
 Helpers::conv_fld_from_array(array, name, "{field_name}", {field_name}, {index});"#
 ));
 }

 body.push_str(
 r#"
 } else {
 throw_or_abort("expected MAP or ARRAY for " + name);
 }"#,
 );
 body
 });
 }
529
530 fn generate_newtype(&mut self, name: &str) {
532 self.msgpack_pack(name, "packer.pack(value);");
533 self.msgpack_unpack(
534 name,
535 &format!(
537 r#"
538 try {{
539 o.convert(value);
540 }} catch (const msgpack::type_error&) {{
541 std::cerr << o << std::endl;
542 throw_or_abort("error converting into newtype '{name}'");
543 }}
544 "#
545 ),
546 );
548 }
549
550 fn generate_tuple(&self, _name: &str, _formats: &[Format]) {
552 unimplemented!("Until we have a tuple enum in our schema we don't need this.");
553 }
554
 /// Emits code for an enum `name`: first its variants, as types nested under the
 /// enum's name, then the enum's own `msgpack_pack` / `msgpack_unpack` methods.
 ///
 /// On the wire, a unit variant is its tag string. Any other variant is a
 /// one-entry MAP from its tag string to its payload.
 fn generate_enum(&mut self, name: &str, variants: &BTreeMap<u32, Named<VariantFormat>>) {
 // Generate each variant type inside the enum's scope.
 self.namespace.push(name.to_string());
 for variant in variants.values() {
 self.generate_variant(&variant.name, &variant.value);
 }
 self.namespace.pop();

 self.msgpack_pack(name, &{
 // One `case` per variant, mapping the C++ `value.index()` to its tag.
 // NOTE(review): uses the registry key `i` as the std::variant index, which
 // assumes keys are 0..n in variant declaration order; confirm.
 let cases = variants
 .iter()
 .map(|(i, v)| {
 format!(
 r#"
 case {i}:
 tag = "{}";
 is_unit = {};
 break;"#,
 v.name,
 matches!(v.value, VariantFormat::Unit)
 )
 })
 .collect::<Vec<_>>()
 .join("");

 format!(
 r#"
 std::string tag;
 bool is_unit;
 switch (value.index()) {{
 {cases}
 default:
 throw_or_abort("unknown enum '{name}' variant index: " + std::to_string(value.index()));
 }}
 if (is_unit) {{
 packer.pack(tag);
 }} else {{
 std::visit([&packer, tag](const auto& arg) {{
 packer.pack_map(1);
 packer.pack(tag);
 packer.pack(arg);
 }}, value);
 }}"#
 )
 });

 self.msgpack_unpack(name, &{
 // Accept either a bare STR tag or a single-entry MAP, and extract the tag.
 let mut body = format!(
 r#"

 if (o.type != msgpack::type::object_type::MAP && o.type != msgpack::type::object_type::STR) {{
 std::cerr << o << std::endl;
 throw_or_abort("expected MAP or STR for enum '{name}'; got type " + std::to_string(o.type));
 }}
 if (o.type == msgpack::type::object_type::MAP && o.via.map.size != 1) {{
 throw_or_abort("expected 1 entry for enum '{name}'; got " + std::to_string(o.via.map.size));
 }}
 std::string tag;
 try {{
 if (o.type == msgpack::type::object_type::MAP) {{
 o.via.map.ptr[0].key.convert(tag);
 }} else {{
 o.convert(tag);
 }}
 }} catch(const msgpack::type_error&) {{
 std::cerr << o << std::endl;
 throw_or_abort("error converting tag to string for enum '{name}'");
 }}"#
 );
 // Build an `if / else if` chain on the tag. The `else ` prefix is omitted only
 // for key 0. NOTE(review): this assumes the first key is 0.
 for (i, v) in variants.iter() {
 let variant = &v.name;
 body.push_str(&format!(
 r#"
 {}if (tag == "{variant}") {{
 {variant} v;"#,
 if *i == 0 { "" } else { "else " }
 ));

 // NOTE(review): for a non-unit variant the generated C++ reads
 // `o.via.map.ptr[0]` even when `o` passed the checks above as a STR, which
 // reads the wrong union member. Adding a type check here would change the
 // generated output (and the hash compared in `serde_cpp_codegen`).
 if !matches!(v.value, VariantFormat::Unit) {
 body.push_str(&format!(
 r#"
 try {{
 o.via.map.ptr[0].val.convert(v);
 }} catch (const msgpack::type_error&) {{
 std::cerr << o << std::endl;
 throw_or_abort("error converting into enum variant '{name}::{variant}'");
 }}
 "#
 ));
 }
 body.push_str(
 r#"
 value = v;
 }"#,
 );
 }
 body.push_str(&format!(
 r#"
 else {{
 std::cerr << o << std::endl;
 throw_or_abort("unknown '{name}' enum variant: " + tag);
 }}"#
 ));
 body
 });
 }
675
676 fn generate_variant(&mut self, name: &str, variant: &VariantFormat) {
678 match variant {
679 VariantFormat::Variable(_) => {
680 unreachable!("internal construct")
681 }
682 VariantFormat::Unit => self.generate_unit_struct(name),
683 VariantFormat::NewType(_format) => self.generate_newtype(name),
684 VariantFormat::Tuple(formats) => self.generate_tuple(name, formats),
685 VariantFormat::Struct(fields) => self.generate_struct(name, fields),
686 }
687 }
688
689 #[allow(dead_code)]
694 fn msgpack_fields(&mut self, name: &str, fields: impl Iterator<Item = String>) {
695 let fields = fields.collect::<Vec<_>>().join(", ");
696 let code = format!("MSGPACK_FIELDS({fields});");
697 self.add_code(name, &code);
698 }
699
700 fn msgpack_pack(&mut self, name: &str, body: &str) {
702 if self.config.no_pack {
703 return;
704 }
705 let code = Self::make_fn("void msgpack_pack(auto& packer) const", body);
706 self.add_code(name, &code);
707 }
708
709 fn msgpack_unpack(&mut self, name: &str, body: &str) {
711 let code = Self::make_fn("void msgpack_unpack(msgpack::object const& o)", body);
749 self.add_code(name, &code);
750 }
751
752 fn make_fn(header: &str, body: &str) -> String {
753 let body = body.trim_end();
754 if body.is_empty() {
755 format!("{header} {{}}")
756 } else if !body.contains('\n') {
757 format!("{header} {{ {body} }}")
758 } else if body.starts_with('\n') {
759 format!("{header} {{{body}\n}}")
760 } else {
761 format!("{header} {{\n{body}\n}}")
762 }
763 }
764 }
765}