brillig/
lengths.rs

1use std::{
2    iter::Sum,
3    ops::{Add, AddAssign, Div, Mul},
4};
5
6use serde::{Deserialize, Serialize};
7
8/// Represents the length of an array or vector as seen from a user's perspective.
9/// For example in the array `[(u8, u16, [u32; 4]); 8]`, the semantic length is 8.
10#[derive(Debug, Copy, Clone, Eq, PartialEq, PartialOrd, Ord, Serialize, Deserialize, Hash)]
11pub struct SemanticLength(pub u32);
12
13impl SemanticLength {
14    pub fn to_usize(self) -> usize {
15        assert_usize(self.0)
16    }
17}
18
19impl Add<SemanticLength> for SemanticLength {
20    type Output = SemanticLength;
21
22    /// Computes the sum of two semantic lengths.
23    fn add(self, rhs: SemanticLength) -> Self::Output {
24        SemanticLength(self.0 + rhs.0)
25    }
26}
27
28impl AddAssign for SemanticLength {
29    /// Adds another semantic length to this one.
30    fn add_assign(&mut self, rhs: Self) {
31        self.0 += rhs.0;
32    }
33}
34
35impl Mul<ElementTypesLength> for SemanticLength {
36    type Output = SemiFlattenedLength;
37
38    /// Computes the semi-flattened length by multiplying the semantic length
39    /// by the element types length.
40    fn mul(self, rhs: ElementTypesLength) -> Self::Output {
41        SemiFlattenedLength(self.0 * rhs.0)
42    }
43}
44
45impl std::fmt::Display for SemanticLength {
46    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
47        write!(f, "{}", self.0)
48    }
49}
50
51/// Represents the number of types of a single element inside a vector or array, without
52/// taking into account the vector or array length.
53/// For example, in the array `[(u8, u16, [u32; 4]); 8]`, the element types length is 3:
54/// 1. u8
55/// 2. u16
56/// 3. [u32; 4]
57#[derive(Debug, Copy, Clone, Eq, PartialEq, PartialOrd, Ord, Serialize, Deserialize, Hash)]
58pub struct ElementTypesLength(pub u32);
59
60impl ElementTypesLength {
61    pub fn to_usize(self) -> usize {
62        assert_usize(self.0)
63    }
64}
65
66impl Mul<SemanticLength> for ElementTypesLength {
67    type Output = SemiFlattenedLength;
68
69    /// Computes the semi-flattened length by multiplying the semantic length
70    /// by the element types length.
71    fn mul(self, rhs: SemanticLength) -> Self::Output {
72        SemiFlattenedLength(self.0 * rhs.0)
73    }
74}
75
76impl Mul<ElementsFlattenedLength> for SemanticLength {
77    type Output = FlattenedLength;
78
79    /// Computes the flattened length by multiplying the semantic length
80    /// by the elements flattened length.
81    fn mul(self, rhs: ElementsFlattenedLength) -> Self::Output {
82        FlattenedLength(self.0 * rhs.0)
83    }
84}
85
86impl std::fmt::Display for ElementTypesLength {
87    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
88        write!(f, "{}", self.0)
89    }
90}
91
92/// Represents the number of value/memory slots required to represent an array or vector.
93/// The semi-flattened length can be computed by multiplying the semantic length by
94/// the element types length.
95///
96/// For example in the array `[(u8, u16, [u32; 4]); 8]`:
97/// - The semantic length is 8
98/// - The element types length is 3
99/// - The semi-flattened length is 24 (8 * 3)
100///
101/// The reason the semi-flattened length is required, and different than the semantic length,
102/// is that in our SSA tuples are flattened so the number of value slots needed to represent an
103/// array is different than the semantic length
104///
105/// Note that this is different from the fully flattened length, which would be 8 * (1 + 1 + 4) = 48.
106#[derive(Debug, Copy, Clone, Eq, PartialEq, PartialOrd, Ord, Serialize, Deserialize, Hash)]
107#[cfg_attr(feature = "arb", derive(proptest_derive::Arbitrary))]
108pub struct SemiFlattenedLength(pub u32);
109
110impl SemiFlattenedLength {
111    pub fn to_usize(self) -> usize {
112        assert_usize(self.0)
113    }
114}
115
116impl std::fmt::Display for SemiFlattenedLength {
117    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
118        write!(f, "{}", self.0)
119    }
120}
121
122impl Div<ElementTypesLength> for SemiFlattenedLength {
123    type Output = SemanticLength;
124
125    fn div(self, rhs: ElementTypesLength) -> Self::Output {
126        if !self.0.is_multiple_of(rhs.0) {
127            panic!(
128                "Division of SemiFlattenedLength {} by ElementTypesLength {} has remainder",
129                self.0, rhs.0
130            );
131        }
132        SemanticLength(self.0 / rhs.0)
133    }
134}
135
136/// Represents the total number of fields required to represent a single entry of an array or vector.
137/// For example in the array `[(u8, u16, [u32; 4]); 8]` the elements flattened length is 6:
138/// 1. u8 (1)
139/// 2. u16 (1)
140/// 3. [u32; 4] (4)
141#[derive(Debug, Copy, Clone, Eq, PartialEq, PartialOrd, Ord, Serialize, Deserialize, Hash)]
142pub struct ElementsFlattenedLength(pub u32);
143
144impl ElementsFlattenedLength {
145    pub fn to_usize(self) -> usize {
146        assert_usize(self.0)
147    }
148}
149
150impl std::fmt::Display for ElementsFlattenedLength {
151    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
152        write!(f, "{}", self.0)
153    }
154}
155
156impl Mul<SemanticLength> for ElementsFlattenedLength {
157    type Output = FlattenedLength;
158
159    /// Computes the flattened length by multiplying the semantic length
160    /// by the elements flattened length.
161    fn mul(self, rhs: SemanticLength) -> Self::Output {
162        FlattenedLength(self.0 * rhs.0)
163    }
164}
165
166impl From<FlattenedLength> for ElementsFlattenedLength {
167    /// Assumes this flattened length represents a single entry in an array or vector,
168    fn from(flattened_length: FlattenedLength) -> Self {
169        Self(flattened_length.0)
170    }
171}
172
173/// Represents the total number of fields required to represent the entirety of an array or vector.
174/// For example in the array `[(u8, u16, [u32; 4]); 8]` the flattened length is 48: 8 * (1 + 1 + 4).
175#[derive(Debug, Copy, Clone, Eq, PartialEq, PartialOrd, Ord, Serialize, Deserialize, Hash)]
176pub struct FlattenedLength(pub u32);
177
178impl FlattenedLength {
179    pub fn to_usize(self) -> usize {
180        assert_usize(self.0)
181    }
182}
183
184impl std::fmt::Display for FlattenedLength {
185    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
186        write!(f, "{}", self.0)
187    }
188}
189
190impl Add for FlattenedLength {
191    type Output = FlattenedLength;
192
193    fn add(self, rhs: Self) -> Self::Output {
194        FlattenedLength(self.0 + rhs.0)
195    }
196}
197
198impl AddAssign for FlattenedLength {
199    fn add_assign(&mut self, rhs: Self) {
200        self.0 += rhs.0;
201    }
202}
203
204impl Sum for FlattenedLength {
205    fn sum<I: Iterator<Item = Self>>(iter: I) -> Self {
206        iter.fold(FlattenedLength(0), |acc, x| acc + x)
207    }
208}
209
210impl Div<ElementsFlattenedLength> for FlattenedLength {
211    type Output = SemanticLength;
212
213    fn div(self, rhs: ElementsFlattenedLength) -> Self::Output {
214        if !self.0.is_multiple_of(rhs.0) {
215            panic!(
216                "Division of FlattenedLength {} by ElementsFlattenedLength {} has remainder",
217                self.0, rhs.0
218            );
219        }
220
221        SemanticLength(self.0 / rhs.0)
222    }
223}
224
/// Turns a `u32` into a `usize`.
///
/// # Panics
///
/// Panics if the value cannot be represented as a `usize`.
fn assert_usize(value: u32) -> usize {
    usize::try_from(value).expect("Failed conversion from u32 to usize")
}