// brillig/lengths.rs

use std::{
    iter::Sum,
    ops::{Add, AddAssign, Div, Mul},
};

use msgpack_tagged::MsgpackTagged;
use serde::{Deserialize, Serialize};

9/// Represents the length of an array or vector as seen from a user's perspective.
10/// For example in the array `[(u8, u16, [u32; 4]); 8]`, the semantic length is 8.
11#[derive(Debug, Copy, Clone, Eq, PartialEq, PartialOrd, Ord, Hash)]
12#[derive(Serialize, Deserialize, MsgpackTagged)]
13pub struct SemanticLength(pub u32);
14
15impl SemanticLength {
16    pub fn to_usize(self) -> usize {
17        assert_usize(self.0)
18    }
19}
20
21impl Add<SemanticLength> for SemanticLength {
22    type Output = SemanticLength;
23
24    /// Computes the sum of two semantic lengths.
25    fn add(self, rhs: SemanticLength) -> Self::Output {
26        SemanticLength(self.0 + rhs.0)
27    }
28}
29
30impl AddAssign for SemanticLength {
31    /// Adds another semantic length to this one.
32    fn add_assign(&mut self, rhs: Self) {
33        self.0 += rhs.0;
34    }
35}
36
37impl Mul<ElementTypesLength> for SemanticLength {
38    type Output = SemiFlattenedLength;
39
40    /// Computes the semi-flattened length by multiplying the semantic length
41    /// by the element types length.
42    fn mul(self, rhs: ElementTypesLength) -> Self::Output {
43        SemiFlattenedLength(self.0 * rhs.0)
44    }
45}
46
47impl std::fmt::Display for SemanticLength {
48    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
49        write!(f, "{}", self.0)
50    }
51}
52
53/// Represents the number of types of a single element inside a vector or array, without
54/// taking into account the vector or array length.
55/// For example, in the array `[(u8, u16, [u32; 4]); 8]`, the element types length is 3:
56/// 1. u8
57/// 2. u16
58/// 3. [u32; 4]
59#[derive(Debug, Copy, Clone, Eq, PartialEq, PartialOrd, Ord, Serialize, Deserialize, Hash)]
60pub struct ElementTypesLength(pub u32);
61
62impl ElementTypesLength {
63    pub fn to_usize(self) -> usize {
64        assert_usize(self.0)
65    }
66}
67
68impl Mul<SemanticLength> for ElementTypesLength {
69    type Output = SemiFlattenedLength;
70
71    /// Computes the semi-flattened length by multiplying the semantic length
72    /// by the element types length.
73    fn mul(self, rhs: SemanticLength) -> Self::Output {
74        SemiFlattenedLength(self.0 * rhs.0)
75    }
76}
77
78impl Mul<ElementsFlattenedLength> for SemanticLength {
79    type Output = FlattenedLength;
80
81    /// Computes the flattened length by multiplying the semantic length
82    /// by the elements flattened length.
83    fn mul(self, rhs: ElementsFlattenedLength) -> Self::Output {
84        FlattenedLength(self.0 * rhs.0)
85    }
86}
87
88impl std::fmt::Display for ElementTypesLength {
89    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
90        write!(f, "{}", self.0)
91    }
92}
93
94/// Represents the number of value/memory slots required to represent an array or vector.
95/// The semi-flattened length can be computed by multiplying the semantic length by
96/// the element types length.
97///
98/// For example in the array `[(u8, u16, [u32; 4]); 8]`:
99/// - The semantic length is 8
100/// - The element types length is 3
101/// - The semi-flattened length is 24 (8 * 3)
102///
103/// The reason the semi-flattened length is required, and different than the semantic length,
104/// is that in our SSA tuples are flattened so the number of value slots needed to represent an
105/// array is different than the semantic length
106///
107/// Note that this is different from the fully flattened length, which would be 8 * (1 + 1 + 4) = 48.
108#[derive(Debug, Copy, Clone, Eq, PartialEq, PartialOrd, Ord, Hash)]
109#[derive(Serialize, Deserialize, MsgpackTagged)]
110#[cfg_attr(feature = "arb", derive(proptest_derive::Arbitrary))]
111pub struct SemiFlattenedLength(pub u32);
112
113impl SemiFlattenedLength {
114    pub fn to_usize(self) -> usize {
115        assert_usize(self.0)
116    }
117}
118
119impl std::fmt::Display for SemiFlattenedLength {
120    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
121        write!(f, "{}", self.0)
122    }
123}
124
125impl Div<ElementTypesLength> for SemiFlattenedLength {
126    type Output = SemanticLength;
127
128    fn div(self, rhs: ElementTypesLength) -> Self::Output {
129        if !self.0.is_multiple_of(rhs.0) {
130            panic!(
131                "Division of SemiFlattenedLength {} by ElementTypesLength {} has remainder",
132                self.0, rhs.0
133            );
134        }
135        SemanticLength(self.0 / rhs.0)
136    }
137}
138
139/// Represents the total number of fields required to represent a single entry of an array or vector.
140/// For example in the array `[(u8, u16, [u32; 4]); 8]` the elements flattened length is 6:
141/// 1. u8 (1)
142/// 2. u16 (1)
143/// 3. [u32; 4] (4)
144#[derive(Debug, Copy, Clone, Eq, PartialEq, PartialOrd, Ord, Serialize, Deserialize, Hash)]
145pub struct ElementsFlattenedLength(pub u32);
146
147impl ElementsFlattenedLength {
148    pub fn to_usize(self) -> usize {
149        assert_usize(self.0)
150    }
151}
152
153impl std::fmt::Display for ElementsFlattenedLength {
154    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
155        write!(f, "{}", self.0)
156    }
157}
158
159impl Mul<SemanticLength> for ElementsFlattenedLength {
160    type Output = FlattenedLength;
161
162    /// Computes the flattened length by multiplying the semantic length
163    /// by the elements flattened length.
164    fn mul(self, rhs: SemanticLength) -> Self::Output {
165        FlattenedLength(self.0 * rhs.0)
166    }
167}
168
169impl From<FlattenedLength> for ElementsFlattenedLength {
170    /// Assumes this flattened length represents a single entry in an array or vector,
171    fn from(flattened_length: FlattenedLength) -> Self {
172        Self(flattened_length.0)
173    }
174}
175
176/// Represents the total number of fields required to represent the entirety of an array or vector.
177/// For example in the array `[(u8, u16, [u32; 4]); 8]` the flattened length is 48: 8 * (1 + 1 + 4).
178#[derive(Debug, Copy, Clone, Eq, PartialEq, PartialOrd, Ord, Serialize, Deserialize, Hash)]
179pub struct FlattenedLength(pub u32);
180
181impl FlattenedLength {
182    pub fn to_usize(self) -> usize {
183        assert_usize(self.0)
184    }
185}
186
187impl std::fmt::Display for FlattenedLength {
188    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
189        write!(f, "{}", self.0)
190    }
191}
192
193impl Add for FlattenedLength {
194    type Output = FlattenedLength;
195
196    fn add(self, rhs: Self) -> Self::Output {
197        FlattenedLength(self.0 + rhs.0)
198    }
199}
200
201impl AddAssign for FlattenedLength {
202    fn add_assign(&mut self, rhs: Self) {
203        self.0 += rhs.0;
204    }
205}
206
207impl Sum for FlattenedLength {
208    fn sum<I: Iterator<Item = Self>>(iter: I) -> Self {
209        iter.fold(FlattenedLength(0), |acc, x| acc + x)
210    }
211}
212
213impl Div<ElementsFlattenedLength> for FlattenedLength {
214    type Output = SemanticLength;
215
216    fn div(self, rhs: ElementsFlattenedLength) -> Self::Output {
217        if !self.0.is_multiple_of(rhs.0) {
218            panic!(
219                "Division of FlattenedLength {} by ElementsFlattenedLength {} has remainder",
220                self.0, rhs.0
221            );
222        }
223
224        SemanticLength(self.0 / rhs.0)
225    }
226}
227
/// Widens a `u32` into a `usize`.
///
/// # Panics
///
/// Panics with "Failed conversion from u32 to usize" if `value` does not fit in a
/// `usize`. That can only happen on targets whose `usize` is narrower than 32 bits.
fn assert_usize(value: u32) -> usize {
    usize::try_from(value).expect("Failed conversion from u32 to usize")
}
231}