1use acir_field::AcirField;
2use serde::{Deserialize, Serialize};
3
4#[derive(Debug, PartialEq, Eq, Serialize, Deserialize, Clone)]
13#[serde(untagged)]
14pub enum ForeignCallParam<F> {
15 Single(F),
17 Array(Vec<F>),
19}
20
21impl<F> From<F> for ForeignCallParam<F> {
22 fn from(value: F) -> Self {
23 ForeignCallParam::Single(value)
24 }
25}
26
27impl<F> From<Vec<F>> for ForeignCallParam<F> {
28 fn from(values: Vec<F>) -> Self {
29 ForeignCallParam::Array(values)
30 }
31}
32
33impl<F: AcirField> ForeignCallParam<F> {
34 pub fn fields(&self) -> Vec<F> {
39 match self {
40 ForeignCallParam::Single(value) => vec![*value],
41 ForeignCallParam::Array(values) => values.to_vec(),
42 }
43 }
44
45 pub fn unwrap_field(&self) -> F {
47 match self {
48 ForeignCallParam::Single(value) => *value,
49 ForeignCallParam::Array(_) => panic!("Expected single value, found array"),
50 }
51 }
52}
53
54#[derive(Debug, PartialEq, Eq, Serialize, Deserialize, Clone, Default)]
60pub struct ForeignCallResult<F> {
61 pub values: Vec<ForeignCallParam<F>>,
65}
66
67impl<F> From<F> for ForeignCallResult<F> {
71 fn from(value: F) -> Self {
72 ForeignCallResult { values: vec![value.into()] }
73 }
74}
75
76impl<F> From<Vec<F>> for ForeignCallResult<F> {
96 fn from(values: Vec<F>) -> Self {
97 ForeignCallResult { values: vec![values.into()] }
98 }
99}
100
101impl<F> From<Vec<ForeignCallParam<F>>> for ForeignCallResult<F> {
106 fn from(values: Vec<ForeignCallParam<F>>) -> Self {
107 ForeignCallResult { values }
108 }
109}
110
#[cfg(test)]
mod tests {
    use super::*;
    use acir_field::FieldElement;

    #[test]
    fn test_foreign_call_param_from_single() {
        let value = FieldElement::from(42u128);
        let param = ForeignCallParam::from(value);

        assert_eq!(param, ForeignCallParam::Single(value));
        assert_eq!(param.fields(), vec![value]);
        assert_eq!(param.unwrap_field(), value);
    }

    #[test]
    fn test_foreign_call_param_from_array() {
        let values =
            vec![FieldElement::from(1u128), FieldElement::from(2u128), FieldElement::from(3u128)];
        let param = ForeignCallParam::from(values.clone());

        assert_eq!(param, ForeignCallParam::Array(values.clone()));
        assert_eq!(param.fields(), values);
    }

    #[test]
    fn test_foreign_call_param_array_roundtrip() {
        let original = vec![
            FieldElement::from(10u128),
            FieldElement::from(20u128),
            FieldElement::from(30u128),
        ];

        let param: ForeignCallParam<FieldElement> = original.clone().into();
        let roundtrip = param.fields();

        assert_eq!(roundtrip, original);
    }

    /// `fields()` flattens a `Single` into a one-element vector.
    #[test]
    fn test_foreign_call_param_single_to_array() {
        let value = FieldElement::from(42u128);
        let param = ForeignCallParam::Single(value);

        assert_eq!(param.fields(), vec![value]);
    }

    #[test]
    #[should_panic(expected = "Expected single value, found array")]
    fn test_foreign_call_param_unwrap_field_panics_on_array() {
        let param =
            ForeignCallParam::Array(vec![FieldElement::from(1u128), FieldElement::from(2u128)]);

        param.unwrap_field();
    }

    #[test]
    fn test_foreign_call_param_empty_array_has_no_fields() {
        // Explicit type: `ForeignCallParam::from(Vec<_>)` is otherwise ambiguous between the
        // `From<F>` and `From<Vec<F>>` impls.
        let param: ForeignCallParam<FieldElement> = ForeignCallParam::from(Vec::new());

        assert_eq!(param, ForeignCallParam::Array(Vec::new()));
        assert!(param.fields().is_empty());
    }

    #[test]
    #[should_panic(expected = "Expected single value, found array")]
    fn test_foreign_call_param_unwrap_field_panics_on_empty_array() {
        let param: ForeignCallParam<FieldElement> = ForeignCallParam::Array(Vec::new());

        param.unwrap_field();
    }

    #[test]
    fn test_foreign_call_result_from_single_value() {
        let value = FieldElement::from(42u128);
        let result = ForeignCallResult::from(value);

        assert_eq!(result.values.len(), 1);
        assert_eq!(result.values[0], ForeignCallParam::Single(value));
    }

    #[test]
    fn test_foreign_call_result_from_vec_creates_single_array_output() {
        let values =
            vec![FieldElement::from(1u128), FieldElement::from(2u128), FieldElement::from(3u128)];
        let result = ForeignCallResult::from(values.clone());

        assert_eq!(result.values.len(), 1);
        assert_eq!(result.values[0], ForeignCallParam::Array(values));
    }

    #[test]
    fn test_foreign_call_result_from_params_creates_multiple_outputs() {
        let params = vec![
            ForeignCallParam::Single(FieldElement::from(1u128)),
            ForeignCallParam::Single(FieldElement::from(2u128)),
            ForeignCallParam::Single(FieldElement::from(3u128)),
        ];
        let result = ForeignCallResult::from(params.clone());

        assert_eq!(result.values.len(), 3);
        assert_eq!(result.values, params);
    }

    #[test]
    fn test_foreign_call_result_from_empty_params_has_no_outputs() {
        let params: Vec<ForeignCallParam<FieldElement>> = Vec::new();
        let result = ForeignCallResult::<FieldElement>::from(params);

        assert!(result.values.is_empty());
    }
}