msgpack_tagged/
registry.rs

1//! Local registry of types participating in tagged-map serialization.
2//!
3//! Built once per encode/decode call by walking the type graph from a top-level
4//! type via [`MsgpackTagged::register_into`]. The wrapper Serializer/Deserializer
5//! consults this registry to translate between serde field names and integer tags.
6//!
7//! ## Wire-shape model
8//!
9//! A tagged type is one of two algebraic shapes:
10//!
11//! * A [`Product`] — a fixed list of `(tag, name)` field entries. Used for
12//!   structs, tuple structs, and (recursively) for an enum variant's payload.
13//! * A [`Sum`] — a discriminated union of [`Variant`]s, each carrying its own
14//!   `Product` payload.
15//!
16//! Both shapes are unified under [`Tagged`], which is the only thing the trait
17//! exposes (via the `TAGGED` associated const). The registry stores `Tagged`
18//! values keyed by serde name and routes wrapper code through the matching arm.
19//!
20//! Every type used here ([`Tagged`], [`Product`], [`Variant`], [`Sum`]) is
21//! `Copy` with public fields — they're built directly in `const` context by
22//! the derive macro and read flatly from the trait, so there's no
23//! encapsulated state to protect.
24
25use std::any::TypeId;
26use std::collections::HashMap;
27
28use crate::{MsgpackTagged, Tag};
29
30/// The wire shape of a tagged type, used both at the top level (in
31/// `MsgpackTagged::TAGGED`) and recursively inside variant payloads.
32#[derive(Clone, Copy, Debug)]
33pub enum Tagged {
34    /// A struct, tuple struct, or any other product-shaped wire type.
35    Product(Product),
36    /// An enum: a discriminated union of variants.
37    Sum(Sum),
38}
39
40impl Tagged {
41    /// Borrow the inner [`Product`] if this is a product shape.
42    pub fn as_product(self) -> Option<Product> {
43        match self {
44            Tagged::Product(p) => Some(p),
45            Tagged::Sum(_) => None,
46        }
47    }
48
49    /// Borrow the inner [`Sum`] if this is a sum shape.
50    pub fn as_sum(self) -> Option<Sum> {
51        match self {
52            Tagged::Sum(s) => Some(s),
53            Tagged::Product(_) => None,
54        }
55    }
56
57    /// Empty [Tagged::Product] used for primitives and _newtypes_.
58    pub const fn empty_product() -> Self {
59        Self::Product(Product::empty())
60    }
61}
62
63/// A product type — a fixed list of named, integer-tagged fields. Used for
64/// top-level structs/tuple structs *and* for an enum variant's payload (a
65/// variant is structurally just a struct hung off a tag).
66///
67/// `fields` is in tag-ascending order (the canonical wire order). `reserved`
68/// lists tags previously used by this product and now retired — purely
69/// compile-time metadata that prevents reuse, never affects decode behavior.
70/// `allow_unknown_tags` opts the decoder into silently skipping fields whose
71/// tag isn't in `fields` or `reserved`. Per-field wire-tolerance (i.e. "fill
72/// `T::default()` when this tag is missing") is **not** modeled here — it's
73/// expressed on the user side via serde-derive's `#[serde(default)]`, which
74/// is what actually performs the substitution at decode time.
75///
76/// `tag_order_matches_source` says the user's source-declaration order is
77/// already tag-ascending — i.e. the order serde-derive will call
78/// `serialize_field` in matches the canonical wire order. Set by the macro
79/// at derive time. The encoder uses it to skip the buffer-and-sort flush
80/// under the `Array` strategy when source order is already correct, saving
81/// a per-field `Vec<u8>` allocation. Under `Tagged` the encoder always
82/// writes direct (no canonical-byte-order promise), so this flag is
83/// unused there.
84#[derive(Clone, Copy, Debug)]
85pub struct Product {
86    pub fields: &'static [(Tag, &'static str)],
87    pub reserved: &'static [Tag],
88    pub allow_unknown_tags: bool,
89    pub tag_order_matches_source: bool,
90}
91
92impl Product {
93    /// Look up a field's tag by its serde name. O(N) over `fields` —
94    /// acceptable for the small (typically 3-30) field counts of ACIR types;
95    /// if a profile ever shows this hot, the registry can precompute
96    /// HashMap views.
97    pub fn tag_for(self, field_name: &str) -> Option<Tag> {
98        self.fields.iter().find(|(_, name)| *name == field_name).map(|(t, _)| *t)
99    }
100
101    /// Look up a field's serde name by its tag.
102    pub fn field_for(self, tag: Tag) -> Option<&'static str> {
103        self.fields.iter().find(|(t, _)| *t == tag).map(|(_, name)| *name)
104    }
105
106    /// Whether `tag` is in the product's reserved list (a retired tag from
107    /// an older schema version).
108    pub fn is_reserved(self, tag: Tag) -> bool {
109        self.reserved.contains(&tag)
110    }
111
112    /// Empty [Product] used for primitives and _newtypes_.
113    pub const fn empty() -> Self {
114        // No fields ⇒ trivially monotonic (no order to violate).
115        Self {
116            fields: &[],
117            reserved: &[],
118            allow_unknown_tags: false,
119            tag_order_matches_source: true,
120        }
121    }
122}
123
/// Discriminator for the shape of an enum variant's payload. The wrapper
/// uses it to decide how to encode/decode the value carried under the
/// variant tag.
///
/// `Unit` and `Newtype` variants both have an empty `payload` `Product`, so
/// this discriminator is the only thing that tells them apart. `Tuple` and
/// `Struct` variants both store their fields in the variant's `payload`
/// `Product` and differ only in how those fields are addressed on the wire.
#[derive(Clone, Copy, Debug, PartialEq, Eq)]
pub enum VariantKind {
    /// Carries no value at all: the wire emits the variant tag with no
    /// payload.
    Unit,
    /// Passes the inner value through directly under the variant tag — a
    /// zero-cost wrapper with no field-level tag/key allocated for the
    /// inner value.
    Newtype,
    /// Fields are addressed by positional names ("0", "1", …).
    ///
    /// A tuple variant with a single explicitly tagged field still counts
    /// as `Tuple`, not `Newtype`: the explicit `#[tag(N)]` is what asks for
    /// a field-level tag wrapping.
    Tuple,
    /// Fields are addressed by their field idents.
    Struct,
}
146
147/// One variant of a sum type. Its payload is a single [`Product`] (possibly
148/// with no fields for unit and newtype variants — see [`VariantKind`] for the
149/// discriminator that distinguishes them).
150#[derive(Clone, Copy, Debug)]
151pub struct Variant {
152    pub tag: Tag,
153    pub name: &'static str,
154    pub kind: VariantKind,
155    pub payload: Product,
156}
157
158/// A sum type — a discriminated union of [`Variant`]s.
159///
160/// `reserved` lists retired *variant* tags. Like `Product::reserved`, this is
161/// always a compile-time tag-reuse guard; whether the runtime decoder routes
162/// such tags to a fallback variant is controlled by `on_reserved_tag`.
163///
164/// `on_reserved_tag` and `on_unknown_tag` opt into runtime-lenient decode of
165/// variant tags. Unlike products' `allow_unknown_tags` (which just skips an
166/// entry), sums can't skip a discriminator — the value is the discriminator —
167/// so the tolerance is expressed as "route to a designated fallback variant,
168/// discarding the payload":
169///
170/// * `on_reserved_tag` — when set, the wire tag of the unit variant that
171///   acts as the backward-compat fallback. The macro fills it in iff a
172///   variant in the source carries `#[tagged(on_reserved)]`. On decode,
173///   any wire tag in `reserved` is routed here (payload discarded).
174/// * `on_unknown_tag` — same shape, but for forward-compat: a variant
175///   marked `#[tagged(on_unknown)]` catches any wire tag that's neither in
176///   `variants` nor in `reserved`. **More dangerous** than `on_reserved`:
177///   silently swallows real corruption alongside future-version tags, so
178///   opt in only when the fallback variant is a safe semantic substitute
179///   for "anything I don't recognize" (e.g. metadata-bearing
180///   `InlineType`-shaped types — definitely not `BrilligOpcode`-shaped
181///   ones, where an unknown discriminator means we can't execute the
182///   program).
183///
184/// A single variant may carry both `#[tagged(on_reserved)]` and
185/// `#[tagged(on_unknown)]` when the user wants the unified-catch-all
186/// behavior; in that case both fields point at the same tag.
187#[derive(Clone, Copy, Debug)]
188pub struct Sum {
189    pub variants: &'static [Variant],
190    pub reserved: &'static [Tag],
191    pub on_reserved_tag: Option<Tag>,
192    pub on_unknown_tag: Option<Tag>,
193}
194
195impl Sum {
196    /// Look up a variant's metadata by its serde name.
197    pub fn variant_for(self, variant_name: &str) -> Option<Variant> {
198        self.variants.iter().find(|v| v.name == variant_name).copied()
199    }
200
201    /// Whether `tag` is in the sum's reserved list (a retired variant tag).
202    pub fn is_reserved(self, tag: Tag) -> bool {
203        self.reserved.contains(&tag)
204    }
205}
206
207/// A registered type's metadata. Stores only the type's [`Tagged`] shape
208/// alongside a `TypeId` used to detect serde-name collisions between
209/// different Rust types.
210#[derive(Debug)]
211pub struct Entry {
212    type_id: TypeId,
213    tagged: Tagged,
214}
215
216impl Entry {
217    /// The type's wire shape — match on it to dispatch product vs. sum.
218    pub fn tagged(&self) -> Tagged {
219        self.tagged
220    }
221
222    /// The registered Rust type's [`TypeId`]. Used by the serializer to
223    /// bridge serde's `&str` name (received at `serialize_struct` time)
224    /// to the `TypeId`-keyed per-type strategy overrides.
225    pub fn type_id(&self) -> TypeId {
226        self.type_id
227    }
228}
229
/// The basename of `std::any::type_name::<T>()`: the module path is
/// stripped and generic parameters are dropped.
///
/// The strategy-override machinery uses this to look up registered types by
/// the same serde name that `#[serde(rename = "...")]` (or the bare type
/// ident) maps to.
///
/// Examples (illustrative — actual results depend on the compiler):
/// * `Circuit<FieldElement>` → `"Circuit"`
/// * `acir::circuit::Program<acir_field::FieldElement>` → `"Program"`
/// * `Vec<u32>` → `"Vec"`
/// * `u32` → `"u32"`
///
/// Caveat: a shadow-DTO type with `#[serde(rename = "Public")]` has
/// `type_name` = `"…::PublicWire"` (the Rust type) but registers under
/// `"Public"` (the serde name). The public type that delegates via
/// `#[tagged(via(PublicWire<F>))]` has `type_name` = `"…::Public"`,
/// which lines up. So passing the public type to
/// `with_strategy::<Public<F>>` works; passing the wire DTO directly
/// (`with_strategy::<PublicWire<F>>`) would silently miss.
pub fn type_name_basename<T: ?Sized>() -> &'static str {
    let qualified: &'static str = std::any::type_name::<T>();

    // Drop generic parameters: keep only what precedes the first `<`.
    let path = match qualified.find('<') {
        Some(open) => &qualified[..open],
        None => qualified,
    };

    // Drop the module path: keep only what follows the last `::`.
    match path.rfind("::") {
        Some(sep) => &path[sep + "::".len()..],
        None => path,
    }
}
255
256/// A registry of types participating in tagged-map serialization.
257#[derive(Default, Debug)]
258pub struct TagRegistry {
259    entries: HashMap<&'static str, Entry>,
260}
261
262impl TagRegistry {
263    /// Construct an empty registry.
264    pub fn new() -> Self {
265        Self::default()
266    }
267
268    /// Construct a registry by starting the type-graph walk at `T`. Calls
269    /// `T::register_into` against a fresh registry, which registers `T`
270    /// itself and then recurses through every reachable tagged field/variant
271    /// type. The standard one-shot way to build a registry for a top-level
272    /// value about to be encoded.
273    ///
274    /// ```ignore
275    /// let registry = TagRegistry::from_type::<Program>();
276    /// ```
277    pub fn from_type<T: ?Sized + MsgpackTagged>() -> Self {
278        let mut reg = Self::new();
279        T::register_into(&mut reg);
280        reg
281    }
282
283    /// Whether `name` corresponds to a registered serde name. Used by
284    /// [`crate::Serializer::with_strategy`] to fail fast when a strategy
285    /// override targets a name the registry never saw — almost always a
286    /// type-graph miss bug. Pair with [`type_name_basename`] when starting
287    /// from a Rust type:
288    ///
289    /// ```ignore
290    /// registry.contains(type_name_basename::<Circuit<F>>())
291    /// ```
292    pub fn contains(&self, name: &str) -> bool {
293        self.entries.contains_key(name)
294    }
295
296    /// Register a type under its serde name.
297    ///
298    /// Returns `true` if this type was newly inserted — the caller (typically a
299    /// macro-generated `register_into` body) should then recurse into the type's
300    /// field types. Returns `false` if the same type was already registered,
301    /// short-circuiting the recursive walk.
302    ///
303    /// **Panics** if a *different* Rust type is already registered under the same
304    /// `name` — that signals a real serde-name collision, which the user must
305    /// resolve with `#[serde(rename = "...")]` on one of the types.
306    pub fn try_insert<T: MsgpackTagged>(&mut self, name: &'static str) -> bool {
307        use std::collections::hash_map::Entry as HashEntry;
308        match self.entries.entry(name) {
309            HashEntry::Vacant(slot) => {
310                slot.insert(Entry { type_id: TypeId::of::<T>(), tagged: T::TAGGED });
311                true
312            }
313            HashEntry::Occupied(slot) => {
314                if slot.get().type_id == TypeId::of::<T>() {
315                    false
316                } else {
317                    panic!(
318                        "MsgpackTagged registry collision: serde name {name:?} is registered for two different Rust types — disambiguate with #[serde(rename = \"...\")] on one of them"
319                    );
320                }
321            }
322        }
323    }
324
325    /// Look up a type's entry by serde name. Returns `None` if the type was
326    /// never registered — the wrapper decides whether that's an error
327    /// (encode-side) or a clean failure to decode (decode-side).
328    pub fn get(&self, name: &str) -> Option<&Entry> {
329        self.entries.get(name)
330    }
331
332    pub fn len(&self) -> usize {
333        self.entries.len()
334    }
335
336    pub fn is_empty(&self) -> bool {
337        self.entries.is_empty()
338    }
339}
340
/// Unit tests for the registry and the shape types, driven by hand-written
/// `MsgpackTagged` impls standing in for derive-macro output.
#[cfg(test)]
mod tests {
    use super::*;
    use std::any::TypeId;

    /// Hand-written struct-shaped impl exercising every `Product` field.
    struct Foo;
    impl MsgpackTagged for Foo {
        const TAGGED: Tagged = Tagged::Product(Product {
            fields: &[(0, "a"), (1, "b")],
            reserved: &[3],
            allow_unknown_tags: true,
            tag_order_matches_source: true,
        });
        fn register_into(_reg: &mut TagRegistry) {}
    }

    /// Hand-written struct with all `Product` extras at their defaults.
    struct Bar;
    impl MsgpackTagged for Bar {
        const TAGGED: Tagged = Tagged::Product(Product {
            fields: &[(0, "x")],
            reserved: &[],
            allow_unknown_tags: false,
            tag_order_matches_source: true,
        });
        fn register_into(_reg: &mut TagRegistry) {}
    }

    /// Hand-written sum-shaped impl: stand-in for what the derive macro will
    /// emit for `enum Choice { #[tag(0)] Empty, #[tag(1)] Pair { #[tag(0)] a, #[tag(2)] b } }`.
    struct Choice;
    impl MsgpackTagged for Choice {
        const TAGGED: Tagged = Tagged::Sum(Sum {
            variants: &[
                Variant {
                    tag: 0,
                    name: "Empty",
                    kind: VariantKind::Unit,
                    payload: Product::empty(),
                },
                Variant {
                    tag: 1,
                    name: "Pair",
                    kind: VariantKind::Struct,
                    payload: Product {
                        fields: &[(0, "a"), (2, "b")],
                        reserved: &[],
                        allow_unknown_tags: false,
                        tag_order_matches_source: true,
                    },
                },
            ],
            reserved: &[5],
            on_reserved_tag: None,
            on_unknown_tag: None,
        });
        fn register_into(_reg: &mut TagRegistry) {}
    }

    /// Hand-written sum exercising both fallback markers together. Mirrors
    /// the derive-macro emission for an enum like
    /// `#[tagged(reserved(7))] enum Lenient { #[tag(0)] A, #[tag(1)] B, #[tag(2)] #[tagged(on_reserved, on_unknown)] Other }`.
    struct Lenient;
    impl MsgpackTagged for Lenient {
        const TAGGED: Tagged = Tagged::Sum(Sum {
            variants: &[
                Variant { tag: 0, name: "A", kind: VariantKind::Unit, payload: Product::empty() },
                Variant { tag: 1, name: "B", kind: VariantKind::Unit, payload: Product::empty() },
                Variant {
                    tag: 2,
                    name: "Other",
                    kind: VariantKind::Unit,
                    payload: Product::empty(),
                },
            ],
            reserved: &[7],
            on_reserved_tag: Some(2),
            on_unknown_tag: Some(2),
        });
        fn register_into(_reg: &mut TagRegistry) {}
    }

    /// The `Product` behind a product-shaped fixture; panics on a sum.
    fn product_of<T: MsgpackTagged>() -> Product {
        T::TAGGED.as_product().expect("expected a product-shaped type")
    }

    /// The `Sum` behind a sum-shaped fixture; panics on a product.
    fn sum_of<T: MsgpackTagged>() -> Sum {
        T::TAGGED.as_sum().expect("expected a sum-shaped type")
    }

    /// Self-registering fixture: unlike `Foo` / `Choice`, this fixture's
    /// `register_into` actually populates the registry — exercises the
    /// `TagRegistry::from_type::<T>` helper end-to-end.
    struct SelfRegistering;
    impl MsgpackTagged for SelfRegistering {
        const TAGGED: Tagged = Tagged::empty_product();
        fn register_into(reg: &mut TagRegistry) {
            reg.try_insert::<Self>("SelfRegistering");
        }
    }

    #[test]
    fn from_type_walks_the_type_graph_from_a_typed_entry_point() {
        let reg = TagRegistry::from_type::<SelfRegistering>();
        assert!(
            reg.get("SelfRegistering").is_some(),
            "SelfRegistering's `register_into` should run via `TagRegistry::from_type`",
        );
    }

    #[test]
    fn try_insert_returns_true_on_first_insert() {
        let mut reg = TagRegistry::new();
        assert!(reg.try_insert::<Foo>("Foo"));
        assert_eq!(reg.len(), 1);
    }

    #[test]
    fn try_insert_returns_false_on_idempotent_reinsert() {
        let mut reg = TagRegistry::new();
        assert!(reg.try_insert::<Foo>("Foo"));
        assert!(!reg.try_insert::<Foo>("Foo"));
        assert_eq!(reg.len(), 1);
    }

    #[test]
    #[should_panic(expected = "registry collision")]
    fn try_insert_panics_on_name_collision_between_different_types() {
        let mut reg = TagRegistry::new();
        reg.try_insert::<Foo>("Same");
        reg.try_insert::<Bar>("Same");
    }

    #[test]
    fn distinct_names_for_different_types_coexist() {
        let mut reg = TagRegistry::new();
        assert!(reg.try_insert::<Foo>("Foo"));
        assert!(reg.try_insert::<Bar>("Bar"));
        assert_eq!(reg.len(), 2);
    }

    #[test]
    fn contains_reports_only_registered_names() {
        let mut reg = TagRegistry::new();
        assert!(!reg.contains("Foo"));
        reg.try_insert::<Foo>("Foo");
        assert!(reg.contains("Foo"));
        assert!(!reg.contains("Bar"));
    }

    #[test]
    fn get_returns_entry_with_the_types_tagged_shape() {
        let mut reg = TagRegistry::new();
        reg.try_insert::<Foo>("Foo");
        let entry = reg.get("Foo").unwrap();
        let p = entry.tagged().as_product().unwrap();
        assert_eq!(p.fields, &[(0, "a"), (1, "b")]);
        assert_eq!(p.reserved, &[3]);
        assert!(p.allow_unknown_tags);
    }

    #[test]
    fn entry_type_id_matches_the_registered_rust_type() {
        let mut reg = TagRegistry::new();
        reg.try_insert::<Foo>("Foo");
        let entry = reg.get("Foo").unwrap();
        assert_eq!(entry.type_id(), TypeId::of::<Foo>());
        assert_ne!(entry.type_id(), TypeId::of::<Bar>());
    }

    #[test]
    fn get_returns_none_for_unknown_name() {
        let reg = TagRegistry::new();
        assert!(reg.get("Anything").is_none());
    }

    #[test]
    fn product_tag_for_finds_known_fields() {
        let p = product_of::<Foo>();
        assert_eq!(p.tag_for("a"), Some(0));
        assert_eq!(p.tag_for("b"), Some(1));
        assert_eq!(p.tag_for("missing"), None);
    }

    #[test]
    fn product_field_for_finds_known_tags() {
        let p = product_of::<Foo>();
        assert_eq!(p.field_for(0), Some("a"));
        assert_eq!(p.field_for(1), Some("b"));
        assert_eq!(p.field_for(99), None);
    }

    #[test]
    fn product_is_reserved_only_for_listed_tags() {
        let p = product_of::<Foo>();
        assert!(p.is_reserved(3));
        assert!(!p.is_reserved(0));
        assert!(!p.is_reserved(99));
    }

    #[test]
    fn empty_registry() {
        let reg = TagRegistry::new();
        assert!(reg.is_empty());
        assert_eq!(reg.len(), 0);
    }

    #[test]
    fn as_product_returns_none_for_sum_shapes() {
        assert!(<Choice as MsgpackTagged>::TAGGED.as_product().is_none());
    }

    #[test]
    fn as_sum_returns_none_for_product_shapes() {
        assert!(<Foo as MsgpackTagged>::TAGGED.as_sum().is_none());
        assert!(<Bar as MsgpackTagged>::TAGGED.as_sum().is_none());
    }

    #[test]
    fn sum_variants_propagate_from_trait_const_to_entry() {
        let mut reg = TagRegistry::new();
        reg.try_insert::<Choice>("Choice");
        let s = reg.get("Choice").unwrap().tagged().as_sum().unwrap();
        assert_eq!(s.variants.len(), 2);
        assert_eq!(s.variants[0].name, "Empty");
        assert_eq!(s.variants[1].name, "Pair");
    }

    #[test]
    fn sum_variant_for_finds_variant_by_name() {
        let s = sum_of::<Choice>();
        let pair = s.variant_for("Pair").expect("`Pair` variant exists");
        assert_eq!(pair.tag, 1);
        assert_eq!(pair.payload.fields, &[(0, "a"), (2, "b")]);
        assert!(s.variant_for("Missing").is_none());
    }

    #[test]
    fn variant_payload_lookups_resolve_payload_field_tags() {
        let pair = sum_of::<Choice>().variant_for("Pair").unwrap();
        assert_eq!(pair.payload.tag_for("a"), Some(0));
        assert_eq!(pair.payload.tag_for("b"), Some(2));
        assert_eq!(pair.payload.tag_for("missing"), None);
        assert_eq!(pair.payload.field_for(0), Some("a"));
        assert_eq!(pair.payload.field_for(2), Some("b"));
        assert_eq!(pair.payload.field_for(99), None);
    }

    /// Unit variants have an empty `fields` slice — the wrapper can rely
    /// on this to short-circuit field-table lookups.
    #[test]
    #[allow(clippy::const_is_empty)]
    fn unit_variants_have_empty_field_table() {
        let empty = sum_of::<Choice>().variant_for("Empty").unwrap();
        assert!(empty.payload.fields.is_empty());
    }

    #[test]
    fn sum_is_reserved_only_for_listed_variant_tags() {
        let s = sum_of::<Choice>();
        assert!(s.is_reserved(5));
        assert!(!s.is_reserved(0));
        assert!(!s.is_reserved(99));
    }

    /// Both fallback-tag slots default to `None` — strict decode unless
    /// the type opts in via a variant-level marker.
    #[test]
    #[allow(clippy::assertions_on_constants)]
    fn sum_default_decode_policy_is_strict() {
        let s = sum_of::<Choice>();
        assert!(s.on_reserved_tag.is_none());
        assert!(s.on_unknown_tag.is_none());
    }

    #[test]
    fn sum_decode_policy_flags_propagate_when_set() {
        let s = sum_of::<Lenient>();
        assert_eq!(s.on_reserved_tag, Some(2));
        assert_eq!(s.on_unknown_tag, Some(2));
    }

    /// Module path and generic parameters are both dropped; a bare
    /// primitive name comes back unchanged.
    #[test]
    fn type_name_basename_strips_module_path_and_generics() {
        assert_eq!(type_name_basename::<u32>(), "u32");
        assert_eq!(type_name_basename::<Vec<u32>>(), "Vec");
        assert_eq!(type_name_basename::<Foo>(), "Foo");
    }
}