1use std::cell::RefCell;
9use std::marker::PhantomData;
10use std::rc::Rc;
11
12use proc_macro2::Span;
13use quote::quote;
14use stageleft::runtime_support::{FreeVariableWithContextWithProps, QuoteTokens};
15
16use crate::compile::ir::{AccessCounter, HydroNode, SharedNode};
17use crate::location::Location;
18
19#[derive(Clone, Copy, Debug, PartialEq, Eq, Hash, serde::Serialize, serde::Deserialize)]
21pub enum HandoffRefKind {
22 Singleton,
24 Optional,
26 Vec,
28}
29
thread_local! {
    /// Per-thread slot for the handoff references captured by the currently
    /// active capture scope (opened by `with_ref_capture`).
    ///
    /// `None` means no capture scope is active on this thread. While a scope
    /// is active it holds `Some(refs)`, and `register_handoff_ref` pushes one
    /// `(HydroNode::Reference { .. }, is_mut)` entry per registered use.
    static CAPTURED_REFS: RefCell<Option<Vec<(HydroNode, bool)>>> = const { RefCell::new(None) };
}
36
37pub(crate) fn handoff_ref_ident(index: usize) -> syn::Ident {
39 syn::Ident::new(
40 &format!("__hydro_singleton_ref_{}", index),
41 Span::call_site(),
42 )
43}
44
45pub fn with_ref_capture(
49 f: impl FnOnce() -> crate::compile::ir::DebugExpr,
50) -> crate::compile::ir::ClosureExpr {
51 CAPTURED_REFS.with(|cell| {
52 let prev = cell.borrow_mut().replace(Vec::new());
53 assert!(
54 prev.is_none(),
55 "nested handoff reference capture scopes are not supported"
56 );
57 });
58 let expr = (f)();
59 let captured_refs = CAPTURED_REFS.with(|cell| cell.borrow_mut().take().unwrap());
60 crate::compile::ir::ClosureExpr::new(expr, captured_refs)
61}
62
63fn register_handoff_ref(
66 ir_node: &RefCell<HydroNode>,
67 is_mut: bool,
68 kind: HandoffRefKind,
69) -> syn::Ident {
70 CAPTURED_REFS.with(|cell| {
71 let mut guard = cell.borrow_mut();
72 let refs = guard.as_mut().expect(
73 "HandoffRef used inside q!() but no reference capture scope is active. \
74 This is a bug — reference capture should be set up by the operator that uses q!().",
75 );
76
77 let index = refs.len();
78 let ident = handoff_ref_ident(index);
79
80 let metadata = ir_node.borrow().metadata().clone();
81
82 if !matches!(&*ir_node.borrow(), HydroNode::Reference { .. }) {
85 let orig = ir_node.replace(HydroNode::Placeholder);
86 *ir_node.borrow_mut() = HydroNode::Reference {
87 inner: SharedNode(Rc::new(RefCell::new(orig))),
88 kind,
89 access_counter: AccessCounter::new(),
90 metadata: metadata.clone(),
91 };
92 }
93
94 let borrow: std::cell::Ref<'_, HydroNode> = ir_node.borrow();
95 let HydroNode::Reference {
96 inner,
97 access_counter,
98 ..
99 } = &*borrow
100 else {
101 unreachable!()
102 };
103
104 let group = access_counter.next_group(is_mut);
106
107 refs.push((
108 HydroNode::Reference {
109 inner: SharedNode(Rc::clone(&inner.0)),
110 kind,
111 access_counter: group,
112 metadata,
113 },
114 is_mut,
115 ));
116
117 ident
118 })
119}
120
/// Generates one handoff-reference handle type per entry.
///
/// Each entry has the form `Name, is_mut, kind, Output`. It may be preceded
/// by attributes, which are forwarded onto the generated struct. For every
/// entry this emits:
/// - a `Copy` struct `Name<'a, 'slf, T, L>` that only borrows a
///   `&'slf RefCell<HydroNode>`;
/// - a crate-private `new` constructor;
/// - a `FreeVariableWithContextWithProps<L, ()>` impl with `type O = Output`.
///   Its `to_tokens` registers the reference through `register_handoff_ref`
///   (passing `is_mut` and `kind`) and expands to the returned identifier.
///   `O` is presumably the type the variable has inside `q!()`; confirm
///   against stageleft.
macro_rules! define_handoff_ref {
    (
        $(
            $(#[$meta:meta])*
            $name:ident, $is_mut:expr, $kind:expr, $output:ty
        )+
    ) => {
        $(
            $(#[$meta])*
            pub struct $name<'a, 'slf, T, L> {
                // The IR node this handle refers to. It is rewritten into a
                // `HydroNode::Reference` the first time the handle is used.
                pub(crate) ir_node: &'slf RefCell<HydroNode>,
                _phantom: PhantomData<(&'a T, L)>,
            }

            impl<'slf, T, L> $name<'_, 'slf, T, L> {
                pub(crate) fn new(ir_node: &'slf RefCell<HydroNode>) -> Self {
                    Self {
                        ir_node,
                        _phantom: PhantomData,
                    }
                }
            }

            // Hand-written instead of derived: `#[derive(Clone, Copy)]` would
            // add `T: Copy` / `L: Copy` bounds, but the handle only stores a
            // shared reference and is always copyable.
            impl<T, L> Copy for $name<'_, '_, T, L> {}
            impl<T, L> Clone for $name<'_, '_, T, L> {
                fn clone(&self) -> Self {
                    *self
                }
            }

            impl<'a, 'slf, T: 'a, L> FreeVariableWithContextWithProps<L, ()> for $name<'a, 'slf, T, L>
            where
                L: Location<'a>,
            {
                type O = $output;

                // Called when the handle appears inside `q!()`. It records
                // the use in the active capture scope and emits only an
                // identifier (no prelude).
                fn to_tokens(self, _ctx: &L) -> (QuoteTokens, ()) {
                    let ident = register_handoff_ref(
                        self.ir_node,
                        $is_mut,
                        $kind,
                    );
                    (
                        QuoteTokens {
                            prelude: None,
                            expr: Some(quote!(#ident)),
                        },
                        (),
                    )
                }
            }
        )+
    };
}
177
// The six public handoff-reference handle types. Each row is
// `Name, is_mut, HandoffRefKind, output type`. The `*Ref` variants register
// read-only uses (`is_mut = false`); the `*Mut` variants register mutable
// uses (`is_mut = true`).
#[stageleft::export(
    SingletonRef,
    SingletonMut,
    OptionalRef,
    OptionalMut,
    StreamRef,
    StreamMut
)]
define_handoff_ref!(
    // Read-only reference to a singleton's value (`&T`).
    SingletonRef, false, HandoffRefKind::Singleton, &'a T

    // Mutable reference to a singleton's value (`&mut T`).
    SingletonMut, true, HandoffRefKind::Singleton, &'a mut T

    // Read-only reference to an optional's value (`&Option<T>`).
    OptionalRef, false, HandoffRefKind::Optional, &'a Option<T>

    // Mutable reference to an optional's value (`&mut Option<T>`).
    OptionalMut, true, HandoffRefKind::Optional, &'a mut Option<T>

    // Read-only reference to a stream's contents (`&Vec<T>`).
    StreamRef, false, HandoffRefKind::Vec, &'a Vec<T>

    // Mutable reference to a stream's contents (`&mut Vec<T>`).
    StreamMut, true, HandoffRefKind::Vec, &'a mut Vec<T>
);
217
// Smoke tests for handoff references. Each test obtains a `by_ref()` or
// `by_mut()` handle from a singleton, optional, or stream, and uses it inside
// the `q!()` closure of some operator. It then still consumes the owning
// collection, builds the flow and calls `finalize()`. A test passes if none
// of these steps panic; no runtime output is checked.
#[cfg(test)]
#[cfg(feature = "build")]
mod tests {
    use stageleft::q;

    use crate::compile::builder::FlowBuilder;
    use crate::location::Location;

    // Marker type naming the single process (`flow.process::<P1>()`) that
    // every test builds on.
    struct P1 {}

    // Read-only singleton reference dereferenced inside `map`.
    #[test]
    fn singleton_by_ref_compiles() {
        let mut flow = FlowBuilder::new();
        let node = flow.process::<P1>();

        let my_count = node
            .source_iter(q!(0..5i32))
            .fold(q!(|| 0i32), q!(|acc: &mut i32, x| *acc += x));
        let count_ref = my_count.by_ref();

        node.source_iter(q!(1..=3i32))
            .map(q!(|x| x + *count_ref))
            .for_each(q!(|_| {}));

        // The owning singleton is still consumed after its reference was used.
        my_count.into_stream().for_each(q!(|_| {}));
        let _built = flow.finalize();
    }

    // Read-only reference to a non-`Copy` singleton (`Vec<i32>`), used
    // through a method call rather than a dereference.
    #[test]
    fn singleton_by_ref_non_copy() {
        let mut flow = FlowBuilder::new();
        let node = flow.process::<P1>();

        let my_vec = node.source_iter(q!(0..5i32)).fold(
            q!(|| Vec::<i32>::new()),
            q!(|acc: &mut Vec<i32>, x| acc.push(x)),
        );
        let vec_ref = my_vec.by_ref();

        node.source_iter(q!(1..=3i32))
            .map(q!(|x| x + vec_ref.len() as i32))
            .for_each(q!(|_| {}));

        my_vec.into_stream().for_each(q!(|_| {}));
        let _built = flow.finalize();
    }

    // Read-only singleton reference inside a `filter` predicate.
    #[test]
    fn singleton_by_ref_filter() {
        let mut flow = FlowBuilder::new();
        let node = flow.process::<P1>();

        let threshold = node
            .source_iter(q!(0..5i32))
            .fold(q!(|| 0i32), q!(|acc: &mut i32, x| *acc += x));
        let threshold_ref = threshold.by_ref();

        node.source_iter(q!(1..=10i32))
            .filter(q!(|x| *x > *threshold_ref))
            .for_each(q!(|_| {}));

        threshold.into_stream().for_each(q!(|_| {}));
        let _built = flow.finalize();
    }

    // Read-only singleton reference inside `flat_map_ordered`, captured into
    // an inner iterator.
    #[test]
    fn singleton_by_ref_flat_map() {
        let mut flow = FlowBuilder::new();
        let node = flow.process::<P1>();

        let count = node
            .source_iter(q!(0..3i32))
            .fold(q!(|| 0i32), q!(|acc: &mut i32, _| *acc += 1));
        let count_ref = count.by_ref();

        node.source_iter(q!(1..=2i32))
            .flat_map_ordered(q!(|x| (0..*count_ref).map(move |i| x + i)))
            .for_each(q!(|_| {}));

        count.into_stream().for_each(q!(|_| {}));
        let _built = flow.finalize();
    }

    // Read-only singleton reference inside `inspect` (formatted with `println!`).
    #[test]
    fn singleton_by_ref_inspect() {
        let mut flow = FlowBuilder::new();
        let node = flow.process::<P1>();

        let count = node
            .source_iter(q!(0..5i32))
            .fold(q!(|| 0i32), q!(|acc: &mut i32, _| *acc += 1));
        let count_ref = count.by_ref();

        node.source_iter(q!(1..=3i32))
            .inspect(q!(|x| println!("count={}, x={}", *count_ref, x)))
            .for_each(q!(|_| {}));

        count.into_stream().for_each(q!(|_| {}));
        let _built = flow.finalize();
    }

    // Read-only singleton reference inside a `partition` predicate; both
    // outputs go straight to `for_each`.
    #[test]
    fn singleton_by_ref_partition() {
        let mut flow = FlowBuilder::new();
        let node = flow.process::<P1>();

        let threshold = node
            .source_iter(q!(0..5i32))
            .fold(q!(|| 0i32), q!(|acc: &mut i32, x| *acc += x));
        let threshold_ref = threshold.by_ref();

        let (above, below) = node
            .source_iter(q!(1..=10i32))
            .partition(q!(|x| *x > *threshold_ref));

        above.for_each(q!(|_| {}));
        below.for_each(q!(|_| {}));
        threshold.into_stream().for_each(q!(|_| {}));
        let _built = flow.finalize();
    }

    // Same as above, but each partition output feeds a further `map`.
    #[test]
    fn singleton_by_ref_partition_with_downstream_ops() {
        let mut flow = FlowBuilder::new();
        let node = flow.process::<P1>();

        let threshold = node
            .source_iter(q!(0..5i32))
            .fold(q!(|| 0i32), q!(|acc: &mut i32, x| *acc += x));
        let threshold_ref = threshold.by_ref();

        let (above, below) = node
            .source_iter(q!(1..=10i32))
            .partition(q!(|x| *x > *threshold_ref));

        above.map(q!(|x| x * 2)).for_each(q!(|_| {}));
        below.map(q!(|x| x + 100)).for_each(q!(|_| {}));
        threshold.into_stream().for_each(q!(|_| {}));
        let _built = flow.finalize();
    }

    // Mutable singleton reference updated inside `map`.
    #[test]
    fn singleton_by_mut_compiles() {
        let mut flow = FlowBuilder::new();
        let node = flow.process::<P1>();

        let my_count = node
            .source_iter(q!(0..5i32))
            .fold(q!(|| 0i32), q!(|acc: &mut i32, x| *acc += x));
        let count_mut = my_count.by_mut();

        node.source_iter(q!(1..=3i32))
            .map(q!(|x| {
                *count_mut += x;
                x
            }))
            .for_each(q!(|_| {}));

        my_count.into_stream().for_each(q!(|_| {}));
        let _built = flow.finalize();
    }

    // Read-only reference to an optional (produced by `reduce`), used via
    // `unwrap_or`.
    #[test]
    fn optional_by_ref_compiles() {
        let mut flow = FlowBuilder::new();
        let node = flow.process::<P1>();

        let my_opt = node.source_iter(q!(0..5i32)).reduce(q!(|a, b| *a += b));
        let opt_ref = my_opt.by_ref();

        node.source_iter(q!(1..=3i32))
            .map(q!(|x| x + opt_ref.unwrap_or(0)))
            .for_each(q!(|_| {}));

        my_opt.into_stream().for_each(q!(|_| {}));
        let _built = flow.finalize();
    }

    // Read-only reference to a stream, used via `len()`.
    #[test]
    fn stream_by_ref_compiles() {
        let mut flow = FlowBuilder::new();
        let node = flow.process::<P1>();

        let my_stream = node.source_iter(q!(0..5i32));
        let stream_ref = my_stream.by_ref();

        node.source_iter(q!(1..=3i32))
            .map(q!(|x| x + stream_ref.len() as i32))
            .for_each(q!(|_| {}));

        my_stream.for_each(q!(|_| {}));
        let _built = flow.finalize();
    }

    // Mutable singleton reference updated and read inside a `filter` predicate.
    #[test]
    fn singleton_by_mut_filter() {
        let mut flow = FlowBuilder::new();
        let node = flow.process::<P1>();

        let my_count = node
            .source_iter(q!(0..5i32))
            .fold(q!(|| 0i32), q!(|acc: &mut i32, x| *acc += x));
        let count_mut = my_count.by_mut();

        node.source_iter(q!(1..=3i32))
            .filter(q!(|x| {
                *count_mut += *x;
                *count_mut > 0
            }))
            .for_each(q!(|_| {}));

        my_count.into_stream().for_each(q!(|_| {}));
        let _built = flow.finalize();
    }

    // Mutable singleton reference updated inside `flat_map_ordered`.
    #[test]
    fn singleton_by_mut_flat_map() {
        let mut flow = FlowBuilder::new();
        let node = flow.process::<P1>();

        let my_count = node
            .source_iter(q!(0..5i32))
            .fold(q!(|| 0i32), q!(|acc: &mut i32, x| *acc += x));
        let count_mut = my_count.by_mut();

        node.source_iter(q!(1..=3i32))
            .flat_map_ordered(q!(|x| {
                *count_mut += x;
                vec![*count_mut]
            }))
            .for_each(q!(|_| {}));

        my_count.into_stream().for_each(q!(|_| {}));
        let _built = flow.finalize();
    }

    // Mutable singleton reference updated inside `filter_map`.
    #[test]
    fn singleton_by_mut_filter_map() {
        let mut flow = FlowBuilder::new();
        let node = flow.process::<P1>();

        let my_count = node
            .source_iter(q!(0..5i32))
            .fold(q!(|| 0i32), q!(|acc: &mut i32, x| *acc += x));
        let count_mut = my_count.by_mut();

        node.source_iter(q!(1..=3i32))
            .filter_map(q!(|x| {
                *count_mut += x;
                Some(*count_mut)
            }))
            .for_each(q!(|_| {}));

        my_count.into_stream().for_each(q!(|_| {}));
        let _built = flow.finalize();
    }

    // Mutable singleton reference updated inside `inspect`.
    #[test]
    fn singleton_by_mut_inspect() {
        let mut flow = FlowBuilder::new();
        let node = flow.process::<P1>();

        let my_count = node
            .source_iter(q!(0..5i32))
            .fold(q!(|| 0i32), q!(|acc: &mut i32, x| *acc += x));
        let count_mut = my_count.by_mut();

        node.source_iter(q!(1..=3i32))
            .inspect(q!(|x| {
                *count_mut += *x;
            }))
            .for_each(q!(|_| {}));

        my_count.into_stream().for_each(q!(|_| {}));
        let _built = flow.finalize();
    }

    // Read-only singleton reference inside the terminal `for_each` itself.
    #[test]
    fn singleton_by_ref_for_each() {
        let mut flow = FlowBuilder::new();
        let node = flow.process::<P1>();

        let my_count = node
            .source_iter(q!(0..5i32))
            .fold(q!(|| 0i32), q!(|acc: &mut i32, x| *acc += x));
        let count_ref = my_count.by_ref();

        node.source_iter(q!(1..=3i32))
            .for_each(q!(|x| println!("{}", x + *count_ref)));

        my_count.into_stream().for_each(q!(|_| {}));
        let _built = flow.finalize();
    }

    // Mutable singleton reference updated inside the terminal `for_each`.
    #[test]
    fn singleton_by_mut_for_each() {
        let mut flow = FlowBuilder::new();
        let node = flow.process::<P1>();

        let my_count = node
            .source_iter(q!(0..5i32))
            .fold(q!(|| 0i32), q!(|acc: &mut i32, x| *acc += x));
        let count_mut = my_count.by_mut();

        node.source_iter(q!(1..=3i32)).for_each(q!(|x| {
            *count_mut += x;
        }));

        my_count.into_stream().for_each(q!(|_| {}));
        let _built = flow.finalize();
    }
}