1use std::collections::HashMap;
4use std::fmt::{Debug, Display};
5use std::ops::{Bound, RangeBounds};
6use std::sync::OnceLock;
7
8use documented::DocumentedVariants;
9use proc_macro2::{Ident, Literal, Span, TokenStream};
10use quote::quote_spanned;
11use serde::{Deserialize, Serialize};
12use slotmap::Key;
13use syn::punctuated::Punctuated;
14use syn::{Expr, Token, parse_quote_spanned};
15
16use super::{
17 GraphLoopId, GraphNode, GraphNodeId, GraphSubgraphId, OpInstGenerics, OperatorInstance,
18 PortIndexValue,
19};
20use crate::diagnostic::{Diagnostic, Diagnostics, Level};
21use crate::parse::{Operator, PortIndex};
22
/// The kind of delay an operator requires on one of its input ports.
///
/// Reported per input port by [`OperatorConstraints::input_delaytype_fn`]. Variant declaration
/// order is significant: it defines the derived `PartialOrd`/`Ord`.
// NOTE(review): per-variant scheduling semantics are not visible in this file; the variant
// notes below are inferred from names — confirm against the scheduler.
#[derive(Debug, Clone, Copy, PartialEq, Eq, PartialOrd, Ord)]
pub enum DelayType {
    /// Delay across a stratum boundary (presumably: input collected in an earlier stratum).
    Stratum,
    /// Monotone accumulation delay.
    MonotoneAccum,
    /// Delay to the next tick (presumably).
    Tick,
    /// Lazy variant of [`Self::Tick`].
    TickLazy,
}
35
36pub enum PortListSpec {
38 Variadic,
40 Fixed(Punctuated<PortIndex, Token![,]>),
42}
43
/// Static description of one dataflow operator: its name, doc categories, accepted
/// input/output/argument counts, and the code generator that emits it.
///
/// One value per operator is listed in [`OPERATORS`] and found by `name` through
/// [`operator_lookup`].
pub struct OperatorConstraints {
    /// Operator name; the key used by [`operator_lookup`].
    pub name: &'static str,
    /// Documentation categories this operator belongs to.
    pub categories: &'static [OperatorCategory],

    // NOTE(review): "hard" vs "soft" ranges are not interpreted in this file; presumably a count
    // outside the hard range is an error and outside the soft range only a warning — confirm.
    /// Hard limit on the number of inputs.
    pub hard_range_inn: &'static dyn RangeTrait<usize>,
    /// Soft limit on the number of inputs.
    pub soft_range_inn: &'static dyn RangeTrait<usize>,
    /// Hard limit on the number of outputs.
    pub hard_range_out: &'static dyn RangeTrait<usize>,
    /// Soft limit on the number of outputs.
    pub soft_range_out: &'static dyn RangeTrait<usize>,
    /// Number of arguments the operator takes (presumably an exact count — confirm).
    pub num_args: usize,
    /// Accepted number of persistence arguments (`'tick`, `'static`, ...; see [`Persistence`]).
    pub persistence_args: &'static dyn RangeTrait<usize>,
    /// Accepted number of type arguments.
    pub type_args: &'static dyn RangeTrait<usize>,
    /// Whether this operator is an external input. NOTE(review): semantics not visible here.
    pub is_external_input: bool,
    /// Whether this operator has a singleton output (presumably exposed through
    /// `WriteContextArgs::singleton_output_ident` — confirm).
    pub has_singleton_output: bool,
    /// Flo role of this operator, if any; see [`FloType`].
    pub flo_type: Option<FloType>,

    /// Optional spec of allowed input ports (presumably `None` means unrestricted — confirm).
    pub ports_inn: Option<fn() -> PortListSpec>,
    /// Optional spec of allowed output ports (presumably `None` means unrestricted — confirm).
    pub ports_out: Option<fn() -> PortListSpec>,

    /// Returns the [`DelayType`] required on the given input port, or `None`.
    pub input_delaytype_fn: fn(&PortIndexValue) -> Option<DelayType>,
    /// Code generator for this operator.
    pub write_fn: WriteFn,
}
88
/// Signature of an operator's code generator.
///
/// Receives the per-instance [`WriteContextArgs`] and a diagnostics sink, and returns the
/// generated code fragments. NOTE(review): on `Err(())` the details are presumably reported
/// through the diagnostics sink — confirm with callers.
pub type WriteFn = fn(&WriteContextArgs<'_>, &mut Diagnostics) -> Result<OperatorWriteOutput, ()>;
91
92impl Debug for OperatorConstraints {
93 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
94 f.debug_struct("OperatorConstraints")
95 .field("name", &self.name)
96 .field("hard_range_inn", &self.hard_range_inn)
97 .field("soft_range_inn", &self.soft_range_inn)
98 .field("hard_range_out", &self.hard_range_out)
99 .field("soft_range_out", &self.soft_range_out)
100 .field("num_args", &self.num_args)
101 .field("persistence_args", &self.persistence_args)
102 .field("type_args", &self.type_args)
103 .field("is_external_input", &self.is_external_input)
104 .field("ports_inn", &self.ports_inn)
105 .field("ports_out", &self.ports_out)
106 .finish()
110 }
111}
112
/// Code fragments produced by an operator's [`WriteFn`].
///
/// Every field defaults to an empty token stream, so an operator fills in only what it needs
/// (e.g. [`IDENTITY_WRITE_FN`] sets only `write_iterator`).
// NOTE(review): where each fragment is placed in the final program is decided by the caller,
// not this file; the per-field notes below are inferred from names — confirm.
#[derive(Default)]
#[non_exhaustive]
pub struct OperatorWriteOutput {
    /// Setup code (presumably emitted once, outside the operator's subgraph).
    pub write_prologue: TokenStream,
    /// Setup code presumably emitted after `write_prologue` / subgraph initialization.
    pub write_prologue_after: TokenStream,
    /// The operator's pull/push pipeline code; it binds the operator's `ident`
    /// (`let #ident = ...`).
    pub write_iterator: TokenStream,
    /// Code presumably emitted after the pipeline code has run.
    pub write_iterator_after: TokenStream,
}
134
135pub const RANGE_ANY: &'static dyn RangeTrait<usize> = &(0..);
137pub const RANGE_0: &'static dyn RangeTrait<usize> = &(0..=0);
139pub const RANGE_1: &'static dyn RangeTrait<usize> = &(1..=1);
141
142pub fn identity_write_iterator_fn(
145 &WriteContextArgs {
146 root,
147 op_span,
148 ident,
149 inputs,
150 outputs,
151 is_pull,
152 op_inst:
153 OperatorInstance {
154 generics: OpInstGenerics { type_args, .. },
155 ..
156 },
157 ..
158 }: &WriteContextArgs,
159) -> TokenStream {
160 let generic_type = type_args
161 .first()
162 .map(quote::ToTokens::to_token_stream)
163 .unwrap_or(quote_spanned!(op_span=> _));
164
165 if is_pull {
166 let input = &inputs[0];
167 quote_spanned! {op_span=>
168 let #ident = {
169 fn check_input<Pull, Item>(pull: Pull) -> impl #root::dfir_pipes::pull::Pull<Item = Item, Meta = Pull::Meta, CanPend = Pull::CanPend, CanEnd = Pull::CanEnd>
170 where
171 Pull: #root::dfir_pipes::pull::Pull<Item = Item>,
172 {
173 pull
174 }
175 check_input::<_, #generic_type>(#input)
176 };
177 }
178 } else {
179 let output = &outputs[0];
180 quote_spanned! {op_span=>
181 let #ident = {
182 fn check_output<Psh, Item>(push: Psh) -> impl #root::dfir_pipes::push::Push<Item, (), CanPend = Psh::CanPend>
183 where
184 Psh: #root::dfir_pipes::push::Push<Item, ()>,
185 {
186 push
187 }
188 check_output::<_, #generic_type>(#output)
189 };
190 }
191 }
192}
193
194pub const IDENTITY_WRITE_FN: WriteFn = |write_context_args, _| {
196 let write_iterator = identity_write_iterator_fn(write_context_args);
197 Ok(OperatorWriteOutput {
198 write_iterator,
199 ..Default::default()
200 })
201};
202
203pub fn null_write_iterator_fn(
206 &WriteContextArgs {
207 root,
208 op_span,
209 ident,
210 inputs,
211 outputs,
212 is_pull,
213 op_inst:
214 OperatorInstance {
215 generics: OpInstGenerics { type_args, .. },
216 ..
217 },
218 ..
219 }: &WriteContextArgs,
220) -> TokenStream {
221 let default_type = parse_quote_spanned! {op_span=> _};
222 let iter_type = type_args.first().unwrap_or(&default_type);
223
224 if is_pull {
225 quote_spanned! {op_span=>
226 let #ident = #root::dfir_pipes::pull::poll_fn({
227 #(
228 let mut #inputs = ::std::boxed::Box::pin(#inputs);
229 )*
230 move |_cx| {
231 #(
235 let #inputs = #root::dfir_pipes::pull::Pull::pull(
236 ::std::pin::Pin::as_mut(&mut #inputs),
237 <_ as #root::dfir_pipes::Context>::from_task(_cx),
238 );
239 )*
240 #(
241 if let #root::dfir_pipes::pull::PullStep::Pending(_) = #inputs {
242 return #root::dfir_pipes::pull::PullStep::Pending(#root::dfir_pipes::Yes);
243 }
244 )*
245 #root::dfir_pipes::pull::PullStep::<_, _, #root::dfir_pipes::Yes, _>::Ended(#root::dfir_pipes::Yes)
246 }
247 });
248 }
249 } else {
250 quote_spanned! {op_span=>
251 #[allow(clippy::let_unit_value)]
252 let _ = (#(#outputs),*);
253 let #ident = #root::dfir_pipes::push::for_each::<_, #iter_type>(::std::mem::drop::<#iter_type>);
254 }
255 }
256}
257
258pub const NULL_WRITE_FN: WriteFn = |write_context_args, _| {
261 let write_iterator = null_write_iterator_fn(write_context_args);
262 Ok(OperatorWriteOutput {
263 write_iterator,
264 ..Default::default()
265 })
266};
267
/// Declares every operator module and collects their constraints into `OPERATORS`.
///
/// Each `module::CONST,` entry expands to `pub(crate) mod module;`, and `module::CONST` is
/// appended to the `OPERATORS` slice in the order listed.
macro_rules! declare_ops {
    ( $( $mod:ident :: $op:ident, )* ) => {
        $( pub(crate) mod $mod; )*
        /// All known operators, in the order given to `declare_ops!`.
        pub const OPERATORS: &[OperatorConstraints] = &[
            $( $mod :: $op, )*
        ];
    };
}
// Registry of all operators. Each `module::CONST` entry declares `pub(crate) mod module;` and
// appends `module::CONST` to `OPERATORS`, preserving this order.
// NOTE(review): the list is not strictly alphabetical (e.g. `fold_keyed`, `lattice_*`, `union`).
// Keep the order unless nothing depends on the ordering of `OPERATORS`.
declare_ops![
    all_iterations::ALL_ITERATIONS,
    all_once::ALL_ONCE,
    anti_join::ANTI_JOIN,
    assert::ASSERT,
    assert_eq::ASSERT_EQ,
    batch::BATCH,
    chain::CHAIN,
    chain_first_n::CHAIN_FIRST_N,
    _counter::_COUNTER,
    cross_join::CROSS_JOIN,
    cross_join_multiset::CROSS_JOIN_MULTISET,
    cross_singleton::CROSS_SINGLETON,
    demux_enum::DEMUX_ENUM,
    dest_file::DEST_FILE,
    dest_sink::DEST_SINK,
    dest_sink_serde::DEST_SINK_SERDE,
    difference::DIFFERENCE,
    enumerate::ENUMERATE,
    filter::FILTER,
    filter_map::FILTER_MAP,
    flat_map::FLAT_MAP,
    flat_map_stream::FLAT_MAP_STREAM,
    flatten::FLATTEN,
    flatten_stream::FLATTEN_STREAM,
    fold::FOLD,
    fold_no_replay::FOLD_NO_REPLAY,
    for_each::FOR_EACH,
    identity::IDENTITY,
    initialize::INITIALIZE,
    inspect::INSPECT,
    join::JOIN,
    join_fused::JOIN_FUSED,
    join_fused_lhs::JOIN_FUSED_LHS,
    join_fused_rhs::JOIN_FUSED_RHS,
    join_multiset::JOIN_MULTISET,
    fold_keyed::FOLD_KEYED,
    reduce_keyed::REDUCE_KEYED,
    repeat_n::REPEAT_N,
    lattice_bimorphism::LATTICE_BIMORPHISM,
    _lattice_fold_batch::_LATTICE_FOLD_BATCH,
    lattice_fold::LATTICE_FOLD,
    _lattice_join_fused_join::_LATTICE_JOIN_FUSED_JOIN,
    lattice_reduce::LATTICE_REDUCE,
    map::MAP,
    union::UNION,
    multiset_delta::MULTISET_DELTA,
    next_iteration::NEXT_ITERATION,
    next_stratum::NEXT_STRATUM,
    defer_signal::DEFER_SIGNAL,
    defer_tick::DEFER_TICK,
    defer_tick_lazy::DEFER_TICK_LAZY,
    null::NULL,
    partition::PARTITION,
    persist::PERSIST,
    persist_mut::PERSIST_MUT,
    persist_mut_keyed::PERSIST_MUT_KEYED,
    prefix::PREFIX,
    resolve_futures::RESOLVE_FUTURES,
    resolve_futures_blocking::RESOLVE_FUTURES_BLOCKING,
    resolve_futures_blocking_ordered::RESOLVE_FUTURES_BLOCKING_ORDERED,
    resolve_futures_ordered::RESOLVE_FUTURES_ORDERED,
    reduce::REDUCE,
    reduce_no_replay::REDUCE_NO_REPLAY,
    scan::SCAN,
    spin::SPIN,
    sort::SORT,
    sort_by_key::SORT_BY_KEY,
    source_file::SOURCE_FILE,
    source_interval::SOURCE_INTERVAL,
    source_iter::SOURCE_ITER,
    source_json::SOURCE_JSON,
    source_stdin::SOURCE_STDIN,
    source_stream::SOURCE_STREAM,
    source_stream_serde::SOURCE_STREAM_SERDE,
    state::STATE,
    state_by::STATE_BY,
    tee::TEE,
    unique::UNIQUE,
    unzip::UNZIP,
    zip::ZIP,
    zip_longest::ZIP_LONGEST,
];
361
362pub fn operator_lookup() -> &'static HashMap<&'static str, &'static OperatorConstraints> {
364 pub static OPERATOR_LOOKUP: OnceLock<HashMap<&'static str, &'static OperatorConstraints>> =
365 OnceLock::new();
366 OPERATOR_LOOKUP.get_or_init(|| OPERATORS.iter().map(|op| (op.name, op)).collect())
367}
368pub fn find_node_op_constraints(node: &GraphNode) -> Option<&'static OperatorConstraints> {
370 if let GraphNode::Operator(operator) = node {
371 find_op_op_constraints(operator)
372 } else {
373 None
374 }
375}
376pub fn find_op_op_constraints(operator: &Operator) -> Option<&'static OperatorConstraints> {
378 let name = &*operator.name_string();
379 operator_lookup().get(name).copied()
380}
381
/// Per-instance context handed to an operator's [`WriteFn`].
#[derive(Clone)]
pub struct WriteContextArgs<'a> {
    /// Path to the runtime crate; generated code refers to items as `#root::...`.
    pub root: &'a TokenStream,
    /// NOTE(review): identifier of a context value in generated code — purpose not visible here.
    pub context: &'a Ident,
    /// NOTE(review): identifier of the dataflow graph value in generated code — confirm.
    pub df_ident: &'a Ident,
    /// Subgraph containing this operator. Used by `make_ident` and for `Persistence::None`
    /// state lifespans.
    pub subgraph_id: GraphSubgraphId,
    /// Graph node of this operator; used by `make_ident`.
    pub node_id: GraphNodeId,
    /// Enclosing loop, if any. When set, the default persistence is `Persistence::None`;
    /// it is also required for `Persistence::Loop`.
    pub loop_id: Option<GraphLoopId>,
    /// Span of the operator in user code; attached to generated tokens and diagnostics.
    pub op_span: Span,
    /// NOTE(review): optional tag for this operator — meaning not visible here.
    pub op_tag: Option<String>,
    /// NOTE(review): identifier of a synchronous work helper in generated code — confirm.
    pub work_fn: &'a Ident,
    /// NOTE(review): identifier of an async work helper in generated code — confirm.
    pub work_fn_async: &'a Ident,

    /// Identifier the generated `write_iterator` code must bind (`let #ident = ...`).
    pub ident: &'a Ident,
    /// `true` to generate pull-based code over `inputs`, `false` for push-based code over
    /// `outputs`.
    pub is_pull: bool,
    /// Identifiers of the operator's input streams.
    pub inputs: &'a [Ident],
    /// Identifiers of the operator's output streams.
    pub outputs: &'a [Ident],
    /// NOTE(review): identifier for the singleton output (see
    /// `OperatorConstraints::has_singleton_output`) — confirm.
    pub singleton_output_ident: &'a Ident,

    /// Operator name, used in diagnostic messages.
    pub op_name: &'static str,
    /// The parsed operator instance, including its generics (persistence and type arguments).
    pub op_inst: &'a OperatorInstance,
    /// The operator's arguments.
    pub arguments: &'a Punctuated<Expr, Token![,]>,
    /// NOTE(review): alternate form of the arguments (presumably rewritten to reference state
    /// handles) — confirm.
    pub arguments_handles: &'a Punctuated<Expr, Token![,]>,
}
433impl WriteContextArgs<'_> {
434 pub fn make_ident(&self, suffix: impl AsRef<str>) -> Ident {
440 Ident::new(
441 &format!(
442 "sg_{:?}_node_{:?}_{}",
443 self.subgraph_id.data(),
444 self.node_id.data(),
445 suffix.as_ref(),
446 ),
447 self.op_span,
448 )
449 }
450
451 pub fn persistence_as_state_lifespan(&self, persistence: Persistence) -> Option<TokenStream> {
454 let root = self.root;
455 let variant =
456 persistence.as_state_lifespan_variant(self.subgraph_id, self.loop_id, self.op_span)?;
457 Some(quote_spanned! {self.op_span=>
458 #root::scheduled::graph::StateLifespan::#variant
459 })
460 }
461
462 pub fn persistence_args_disallow_mutable<const N: usize>(
464 &self,
465 diagnostics: &mut Diagnostics,
466 ) -> [Persistence; N] {
467 let len = self.op_inst.generics.persistence_args.len();
468 if 0 != len && 1 != len && N != len {
469 diagnostics.push(Diagnostic::spanned(
470 self.op_span,
471 Level::Error,
472 format!(
473 "The operator `{}` only accepts 0, 1, or {} persistence arguments",
474 self.op_name, N
475 ),
476 ));
477 }
478
479 let default_persistence = if self.loop_id.is_some() {
480 Persistence::None
481 } else {
482 Persistence::Tick
483 };
484 let mut out = [default_persistence; N];
485 self.op_inst
486 .generics
487 .persistence_args
488 .iter()
489 .copied()
490 .cycle() .take(N)
492 .enumerate()
493 .filter(|&(_i, p)| {
494 if p == Persistence::Mutable {
495 diagnostics.push(Diagnostic::spanned(
496 self.op_span,
497 Level::Error,
498 format!(
499 "An implementation of `'{}` does not exist",
500 p.to_str_lowercase()
501 ),
502 ));
503 false
504 } else {
505 true
506 }
507 })
508 .for_each(|(i, p)| {
509 out[i] = p;
510 });
511 out
512 }
513}
514
/// Object-safe stand-in for [`RangeBounds`], used for operator count limits.
///
/// It has no generic methods and requires `Send + Sync + Debug`. That lets ranges be stored
/// as `&'static dyn RangeTrait<usize>` and printed. Every `RangeBounds<T>` type that is
/// `Send + Sync + Debug` implements it via the blanket impl below.
pub trait RangeTrait<T>: Send + Sync + Debug
where
    T: ?Sized,
{
    /// Lower bound of the range.
    fn start_bound(&self) -> Bound<&T>;
    /// Upper bound of the range.
    fn end_bound(&self) -> Bound<&T>;
    /// Whether `item` lies within the range.
    fn contains(&self, item: &T) -> bool
    where
        T: PartialOrd<T>;

    /// Describes the range as an English phrase, e.g. `"exactly 1"`, `"at least 2"`,
    /// `"more than 0 and at most 3"`, or `"any number of"` for a fully unbounded range.
    fn human_string(&self) -> String
    where
        T: Display + PartialEq,
    {
        let lower = self.start_bound();
        let upper = self.end_bound();
        // Dispatch on the lower bound first, then describe the upper bound relative to it.
        match lower {
            Bound::Unbounded => match upper {
                Bound::Unbounded => "any number of".to_owned(),
                Bound::Included(x) => format!("at most {}", x),
                Bound::Excluded(x) => format!("less than {}", x),
            },
            Bound::Included(n) => match upper {
                Bound::Included(x) if n == x => format!("exactly {}", n),
                Bound::Included(x) => format!("at least {} and at most {}", n, x),
                Bound::Excluded(x) => format!("at least {} and less than {}", n, x),
                Bound::Unbounded => format!("at least {}", n),
            },
            Bound::Excluded(n) => match upper {
                Bound::Included(x) => format!("more than {} and at most {}", n, x),
                Bound::Excluded(x) => format!("more than {} and less than {}", n, x),
                Bound::Unbounded => format!("more than {}", n),
            },
        }
    }
}

/// Blanket impl: every printable, thread-safe `RangeBounds<T>` is a `RangeTrait<T>`.
/// Fully-qualified calls make it explicit that these delegate to `RangeBounds` rather than
/// recursing into `RangeTrait`.
impl<R, T> RangeTrait<T> for R
where
    R: RangeBounds<T> + Send + Sync + Debug,
{
    fn start_bound(&self) -> Bound<&T> {
        <R as RangeBounds<T>>::start_bound(self)
    }

    fn end_bound(&self) -> Bound<&T> {
        <R as RangeBounds<T>>::end_bound(self)
    }

    fn contains(&self, item: &T) -> bool
    where
        T: PartialOrd<T>,
    {
        <R as RangeBounds<T>>::contains(self, item)
    }
}
579
580#[derive(Clone, Copy, PartialOrd, Ord, PartialEq, Eq, Debug, Serialize, Deserialize)]
582pub enum Persistence {
583 None,
585 Loop,
587 Tick,
589 Static,
591 Mutable,
593}
594impl Persistence {
595 pub fn as_state_lifespan_variant(
597 self,
598 subgraph_id: GraphSubgraphId,
599 loop_id: Option<GraphLoopId>,
600 span: Span,
601 ) -> Option<TokenStream> {
602 match self {
603 Persistence::None => {
604 let sg_ident = subgraph_id.as_ident(span);
605 Some(quote_spanned!(span=> Subgraph(#sg_ident)))
606 }
607 Persistence::Loop => {
608 let loop_ident = loop_id
609 .expect("`Persistence::Loop` outside of a loop context.")
610 .as_ident(span);
611 Some(quote_spanned!(span=> Loop(#loop_ident)))
612 }
613 Persistence::Tick => Some(quote_spanned!(span=> Tick)),
614 Persistence::Static => None,
615 Persistence::Mutable => None,
616 }
617 }
618
619 pub fn to_str_lowercase(self) -> &'static str {
621 match self {
622 Persistence::None => "none",
623 Persistence::Tick => "tick",
624 Persistence::Loop => "loop",
625 Persistence::Static => "static",
626 Persistence::Mutable => "mutable",
627 }
628 }
629}
630
631fn make_missing_runtime_msg(op_name: &str) -> Literal {
633 Literal::string(&format!(
634 "`{}()` must be used within a Tokio runtime. For example, use `#[dfir_rs::main]` on your main method.",
635 op_name
636 ))
637}
638
/// Documentation category of an operator (listed in `OperatorConstraints::categories`).
//
// NOTE(review): `DocumentedVariants` presumably exposes each variant's `///` doc comment via
// `get_variant_docs()`. `OperatorCategory::name` / `description` split that text at the
// first `:` and panic if there is none. Every variant therefore needs a doc comment of the
// form `/// Name: description`. Plain `//` comments (like this one) are not seen by the derive.
#[derive(Clone, Copy, Debug, PartialEq, Eq, PartialOrd, Ord, Hash, DocumentedVariants)]
pub enum OperatorCategory {
    Map,
    Filter,
    Flatten,
    Fold,
    KeyedFold,
    LatticeFold,
    Persistence,
    MultiIn,
    MultiOut,
    Source,
    Sink,
    Control,
    CompilerFusionOperator,
    Windowing,
    Unwindowing,
}
673impl OperatorCategory {
674 pub fn name(self) -> &'static str {
676 self.get_variant_docs().split_once(":").unwrap().0
677 }
678 pub fn description(self) -> &'static str {
680 self.get_variant_docs().split_once(":").unwrap().1
681 }
682}
683
/// "Flo" role of an operator, stored in `OperatorConstraints::flo_type`.
///
/// Variant declaration order is significant: it defines the derived `PartialOrd`/`Ord`.
// NOTE(review): per-variant meanings are not visible in this file; they presumably relate to
// `loop` contexts (cf. the `Windowing`/`Unwindowing` operator categories) — confirm.
#[derive(Debug, Clone, Copy, PartialEq, Eq, PartialOrd, Ord)]
pub enum FloType {
    /// A source operator.
    Source,
    /// A windowing operator.
    Windowing,
    /// An un-windowing operator.
    Unwindowing,
    /// A next-iteration operator.
    NextIteration,
}