1use std::collections::HashMap;
4use std::fmt::{Debug, Display};
5use std::ops::{Bound, RangeBounds};
6use std::sync::OnceLock;
7
8use documented::DocumentedVariants;
9use proc_macro2::{Ident, Literal, Span, TokenStream};
10use quote::quote_spanned;
11use serde::{Deserialize, Serialize};
12use slotmap::Key;
13use syn::punctuated::Punctuated;
14use syn::{Expr, Token, parse_quote_spanned};
15
16use super::{
17 GraphLoopId, GraphNode, GraphNodeId, GraphSubgraphId, OpInstGenerics, OperatorInstance,
18 PortIndexValue,
19};
20use crate::diagnostic::{Diagnostic, Level};
21use crate::parse::{Operator, PortIndex};
22
/// The kind of delay an operator can declare for one of its input ports, as returned by
/// [`OperatorConstraints::input_delaytype_fn`].
///
/// The derived ordering follows declaration order:
/// `Stratum < MonotoneAccum < Tick < TickLazy`.
// NOTE(review): the per-variant notes are inferred from the variant names and from the
// similarly named operators (`next_stratum`, `defer_tick`, `defer_tick_lazy`) — confirm.
#[derive(Clone, Copy, PartialOrd, Ord, PartialEq, Eq, Debug)]
pub enum DelayType {
    /// Presumably: the input is delayed until a later stratum.
    Stratum,
    /// Presumably: a stratum-style delay for a monotonically accumulating input.
    MonotoneAccum,
    /// Presumably: the input is delayed until the next tick.
    Tick,
    /// Presumably: like [`DelayType::Tick`], but without forcing the next tick to run.
    TickLazy,
}
35
36pub enum PortListSpec {
38 Variadic,
40 Fixed(Punctuated<PortIndex, Token![,]>),
42}
43
44pub struct OperatorConstraints {
46 pub name: &'static str,
48 pub categories: &'static [OperatorCategory],
50
51 pub hard_range_inn: &'static dyn RangeTrait<usize>,
54 pub soft_range_inn: &'static dyn RangeTrait<usize>,
56 pub hard_range_out: &'static dyn RangeTrait<usize>,
58 pub soft_range_out: &'static dyn RangeTrait<usize>,
60 pub num_args: usize,
62 pub persistence_args: &'static dyn RangeTrait<usize>,
64 pub type_args: &'static dyn RangeTrait<usize>,
68 pub is_external_input: bool,
71 pub has_singleton_output: bool,
75 pub flo_type: Option<FloType>,
77
78 pub ports_inn: Option<fn() -> PortListSpec>,
80 pub ports_out: Option<fn() -> PortListSpec>,
82
83 pub input_delaytype_fn: fn(&PortIndexValue) -> Option<DelayType>,
85 pub write_fn: WriteFn,
87}
88
89pub type WriteFn =
91 fn(&WriteContextArgs<'_>, &mut Vec<Diagnostic>) -> Result<OperatorWriteOutput, ()>;
92
93impl Debug for OperatorConstraints {
94 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
95 f.debug_struct("OperatorConstraints")
96 .field("name", &self.name)
97 .field("hard_range_inn", &self.hard_range_inn)
98 .field("soft_range_inn", &self.soft_range_inn)
99 .field("hard_range_out", &self.hard_range_out)
100 .field("soft_range_out", &self.soft_range_out)
101 .field("num_args", &self.num_args)
102 .field("persistence_args", &self.persistence_args)
103 .field("type_args", &self.type_args)
104 .field("is_external_input", &self.is_external_input)
105 .field("ports_inn", &self.ports_inn)
106 .field("ports_out", &self.ports_out)
107 .finish()
111 }
112}
113
114#[derive(Default)]
116#[non_exhaustive]
117pub struct OperatorWriteOutput {
118 pub write_prologue: TokenStream,
122 pub write_prologue_after: TokenStream,
125 pub write_iterator: TokenStream,
132 pub write_iterator_after: TokenStream,
134}
135
136pub const RANGE_ANY: &'static dyn RangeTrait<usize> = &(0..);
138pub const RANGE_0: &'static dyn RangeTrait<usize> = &(0..=0);
140pub const RANGE_1: &'static dyn RangeTrait<usize> = &(1..=1);
142
143pub fn identity_write_iterator_fn(
146 &WriteContextArgs {
147 root,
148 op_span,
149 ident,
150 inputs,
151 outputs,
152 is_pull,
153 op_inst:
154 OperatorInstance {
155 generics: OpInstGenerics { type_args, .. },
156 ..
157 },
158 ..
159 }: &WriteContextArgs,
160) -> TokenStream {
161 let generic_type = type_args
162 .first()
163 .map(quote::ToTokens::to_token_stream)
164 .unwrap_or(quote_spanned!(op_span=> _));
165
166 if is_pull {
167 let input = &inputs[0];
168 quote_spanned! {op_span=>
169 let #ident = {
170 fn check_input<Iter: ::std::iter::Iterator<Item = Item>, Item>(iter: Iter) -> impl ::std::iter::Iterator<Item = Item> { iter }
171 check_input::<_, #generic_type>(#input)
172 };
173 }
174 } else {
175 let output = &outputs[0];
176 quote_spanned! {op_span=>
177 let #ident = {
178 fn check_output<Push: #root::pusherator::Pusherator<Item = Item>, Item>(push: Push) -> impl #root::pusherator::Pusherator<Item = Item> { push }
179 check_output::<_, #generic_type>(#output)
180 };
181 }
182 }
183}
184
185pub const IDENTITY_WRITE_FN: WriteFn = |write_context_args, _| {
187 let write_iterator = identity_write_iterator_fn(write_context_args);
188 Ok(OperatorWriteOutput {
189 write_iterator,
190 ..Default::default()
191 })
192};
193
194pub fn null_write_iterator_fn(
197 &WriteContextArgs {
198 root,
199 op_span,
200 ident,
201 inputs,
202 outputs,
203 is_pull,
204 op_inst:
205 OperatorInstance {
206 generics: OpInstGenerics { type_args, .. },
207 ..
208 },
209 ..
210 }: &WriteContextArgs,
211) -> TokenStream {
212 let default_type = parse_quote_spanned! {op_span=> _};
213 let iter_type = type_args.first().unwrap_or(&default_type);
214
215 if is_pull {
216 quote_spanned! {op_span=>
217 #(
218 #inputs.for_each(std::mem::drop);
219 )*
220 let #ident = std::iter::empty::<#iter_type>();
221 }
222 } else {
223 quote_spanned! {op_span=>
224 #[allow(clippy::let_unit_value)]
225 let _ = (#(#outputs),*);
226 let #ident = #root::pusherator::null::Null::<#iter_type>::new();
227 }
228 }
229}
230
231pub const NULL_WRITE_FN: WriteFn = |write_context_args, _| {
234 let write_iterator = null_write_iterator_fn(write_context_args);
235 Ok(OperatorWriteOutput {
236 write_iterator,
237 ..Default::default()
238 })
239};
240
/// Declares the operator submodules and the `OPERATORS` table in one place.
///
/// Each `module::CONSTANT,` entry (the trailing comma is required) expands to a
/// `pub(crate) mod module;` declaration plus one element `module::CONSTANT` of
/// `OPERATORS`, so the table is in the same order as the invocation.
macro_rules! declare_ops {
    ( $( $mod:ident :: $op:ident, )* ) => {
        $( pub(crate) mod $mod; )*
        /// Constraints of every declared operator, in declaration order.
        pub const OPERATORS: &[OperatorConstraints] = &[
            $( $mod :: $op, )*
        ];
    };
}
250declare_ops![
251 all_iterations::ALL_ITERATIONS,
252 all_once::ALL_ONCE,
253 anti_join::ANTI_JOIN,
254 anti_join_multiset::ANTI_JOIN_MULTISET,
255 assert::ASSERT,
256 assert_eq::ASSERT_EQ,
257 batch::BATCH,
258 chain::CHAIN,
259 _counter::_COUNTER,
260 cross_join::CROSS_JOIN,
261 cross_join_multiset::CROSS_JOIN_MULTISET,
262 cross_singleton::CROSS_SINGLETON,
263 demux::DEMUX,
264 demux_enum::DEMUX_ENUM,
265 dest_file::DEST_FILE,
266 dest_sink::DEST_SINK,
267 dest_sink_serde::DEST_SINK_SERDE,
268 difference::DIFFERENCE,
269 difference_multiset::DIFFERENCE_MULTISET,
270 enumerate::ENUMERATE,
271 filter::FILTER,
272 filter_map::FILTER_MAP,
273 flat_map::FLAT_MAP,
274 flatten::FLATTEN,
275 fold::FOLD,
276 for_each::FOR_EACH,
277 identity::IDENTITY,
278 initialize::INITIALIZE,
279 inspect::INSPECT,
280 join::JOIN,
281 join_fused::JOIN_FUSED,
282 join_fused_lhs::JOIN_FUSED_LHS,
283 join_fused_rhs::JOIN_FUSED_RHS,
284 join_multiset::JOIN_MULTISET,
285 fold_keyed::FOLD_KEYED,
286 reduce_keyed::REDUCE_KEYED,
287 repeat_n::REPEAT_N,
288 lattice_bimorphism::LATTICE_BIMORPHISM,
290 _lattice_fold_batch::_LATTICE_FOLD_BATCH,
291 lattice_fold::LATTICE_FOLD,
292 _lattice_join_fused_join::_LATTICE_JOIN_FUSED_JOIN,
293 lattice_reduce::LATTICE_REDUCE,
294 map::MAP,
295 union::UNION,
296 multiset_delta::MULTISET_DELTA,
297 next_iteration::NEXT_ITERATION,
298 next_stratum::NEXT_STRATUM,
299 defer_signal::DEFER_SIGNAL,
300 defer_tick::DEFER_TICK,
301 defer_tick_lazy::DEFER_TICK_LAZY,
302 null::NULL,
303 partition::PARTITION,
304 persist::PERSIST,
305 persist_mut::PERSIST_MUT,
306 persist_mut_keyed::PERSIST_MUT_KEYED,
307 prefix::PREFIX,
308 resolve_futures::RESOLVE_FUTURES,
309 resolve_futures_ordered::RESOLVE_FUTURES_ORDERED,
310 py_udf::PY_UDF,
311 reduce::REDUCE,
312 scan::SCAN,
313 spin::SPIN,
314 sort::SORT,
315 sort_by_key::SORT_BY_KEY,
316 source_file::SOURCE_FILE,
317 source_interval::SOURCE_INTERVAL,
318 source_iter::SOURCE_ITER,
319 source_json::SOURCE_JSON,
320 source_stdin::SOURCE_STDIN,
321 source_stream::SOURCE_STREAM,
322 source_stream_serde::SOURCE_STREAM_SERDE,
323 state::STATE,
324 state_by::STATE_BY,
325 tee::TEE,
326 unique::UNIQUE,
327 unzip::UNZIP,
328 zip::ZIP,
329 zip_longest::ZIP_LONGEST,
330];
331
332pub fn operator_lookup() -> &'static HashMap<&'static str, &'static OperatorConstraints> {
334 pub static OPERATOR_LOOKUP: OnceLock<HashMap<&'static str, &'static OperatorConstraints>> =
335 OnceLock::new();
336 OPERATOR_LOOKUP.get_or_init(|| OPERATORS.iter().map(|op| (op.name, op)).collect())
337}
338pub fn find_node_op_constraints(node: &GraphNode) -> Option<&'static OperatorConstraints> {
340 if let GraphNode::Operator(operator) = node {
341 find_op_op_constraints(operator)
342 } else {
343 None
344 }
345}
346pub fn find_op_op_constraints(operator: &Operator) -> Option<&'static OperatorConstraints> {
348 let name = &*operator.name_string();
349 operator_lookup().get(name).copied()
350}
351
352#[derive(Clone)]
354pub struct WriteContextArgs<'a> {
355 pub root: &'a TokenStream,
357 pub context: &'a Ident,
360 pub df_ident: &'a Ident,
364 pub subgraph_id: GraphSubgraphId,
366 pub node_id: GraphNodeId,
368 pub loop_id: Option<GraphLoopId>,
370 pub op_span: Span,
372 pub op_tag: Option<String>,
374 pub work_fn: &'a Ident,
376
377 pub ident: &'a Ident,
379 pub is_pull: bool,
381 pub inputs: &'a [Ident],
383 pub outputs: &'a [Ident],
385 pub singleton_output_ident: &'a Ident,
387
388 pub op_name: &'static str,
390 pub op_inst: &'a OperatorInstance,
392 pub arguments: &'a Punctuated<Expr, Token![,]>,
398 pub arguments_handles: &'a Punctuated<Expr, Token![,]>,
400}
401impl WriteContextArgs<'_> {
402 pub fn make_ident(&self, suffix: impl AsRef<str>) -> Ident {
408 Ident::new(
409 &format!(
410 "sg_{:?}_node_{:?}_{}",
411 self.subgraph_id.data(),
412 self.node_id.data(),
413 suffix.as_ref(),
414 ),
415 self.op_span,
416 )
417 }
418
419 pub fn persistence_as_state_lifespan(&self, persistence: Persistence) -> Option<TokenStream> {
422 let root = self.root;
423 let variant =
424 persistence.as_state_lifespan_variant(self.subgraph_id, self.loop_id, self.op_span)?;
425 Some(quote_spanned! {self.op_span=>
426 #root::scheduled::graph::StateLifespan::#variant
427 })
428 }
429
430 pub fn persistence_args_disallow_mutable<const N: usize>(
432 &self,
433 diagnostics: &mut Vec<Diagnostic>,
434 ) -> [Persistence; N] {
435 let len = self.op_inst.generics.persistence_args.len();
436 if 0 != len && 1 != len && N != len {
437 diagnostics.push(Diagnostic::spanned(
438 self.op_span,
439 Level::Error,
440 format!(
441 "The operator `{}` only accepts 0, 1, or {} persistence arguments",
442 self.op_name, N
443 ),
444 ));
445 }
446
447 let default_persistence = if self.loop_id.is_some() {
448 Persistence::None
449 } else {
450 Persistence::Tick
451 };
452 let mut out = [default_persistence; N];
453 self.op_inst
454 .generics
455 .persistence_args
456 .iter()
457 .copied()
458 .cycle() .take(N)
460 .enumerate()
461 .filter(|&(_i, p)| {
462 if p == Persistence::Mutable {
463 diagnostics.push(Diagnostic::spanned(
464 self.op_span,
465 Level::Error,
466 format!(
467 "An implementation of `'{}` does not exist",
468 p.to_str_lowercase()
469 ),
470 ));
471 false
472 } else {
473 true
474 }
475 })
476 .for_each(|(i, p)| {
477 out[i] = p;
478 });
479 out
480 }
481}
482
/// A range abstraction that can be used as a trait object (this file stores ranges as
/// `&'static dyn RangeTrait<usize>`), with the same three methods as
/// `std::ops::RangeBounds` plus a human-readable description.
pub trait RangeTrait<T>: Send + Sync + Debug
where
    T: ?Sized,
{
    /// Lower bound of the range.
    fn start_bound(&self) -> Bound<&T>;
    /// Upper bound of the range.
    fn end_bound(&self) -> Bound<&T>;
    /// Returns `true` if `item` lies within the range.
    fn contains(&self, item: &T) -> bool
    where
        T: PartialOrd<T>;

    /// Describes the range in words that read naturally before a noun, e.g.
    /// `"exactly 1"`, `"at least 2"`, `"at least 1 and at most 3"`, or `"any number of"`.
    fn human_string(&self) -> String
    where
        T: Display + PartialEq,
    {
        match (self.start_bound(), self.end_bound()) {
            (Bound::Unbounded, Bound::Unbounded) => "any number of".to_owned(),

            // Must come before the general inclusive/inclusive arm.
            (Bound::Included(n), Bound::Included(x)) if n == x => {
                format!("exactly {n}")
            }
            (Bound::Included(n), Bound::Included(x)) => {
                format!("at least {n} and at most {x}")
            }
            (Bound::Included(n), Bound::Excluded(x)) => {
                format!("at least {n} and less than {x}")
            }
            (Bound::Included(n), Bound::Unbounded) => format!("at least {n}"),
            (Bound::Excluded(n), Bound::Included(x)) => {
                format!("more than {n} and at most {x}")
            }
            (Bound::Excluded(n), Bound::Excluded(x)) => {
                format!("more than {n} and less than {x}")
            }
            (Bound::Excluded(n), Bound::Unbounded) => format!("more than {n}"),
            (Bound::Unbounded, Bound::Included(x)) => format!("at most {x}"),
            (Bound::Unbounded, Bound::Excluded(x)) => format!("less than {x}"),
        }
    }
}
527
528impl<R, T> RangeTrait<T> for R
529where
530 R: RangeBounds<T> + Send + Sync + Debug,
531{
532 fn start_bound(&self) -> Bound<&T> {
533 self.start_bound()
534 }
535
536 fn end_bound(&self) -> Bound<&T> {
537 self.end_bound()
538 }
539
540 fn contains(&self, item: &T) -> bool
541 where
542 T: PartialOrd<T>,
543 {
544 self.contains(item)
545 }
546}
547
548#[derive(Clone, Copy, PartialOrd, Ord, PartialEq, Eq, Debug, Serialize, Deserialize)]
550pub enum Persistence {
551 None,
553 Loop,
555 Tick,
557 Static,
559 Mutable,
561}
562impl Persistence {
563 pub fn as_state_lifespan_variant(
565 self,
566 subgraph_id: GraphSubgraphId,
567 loop_id: Option<GraphLoopId>,
568 span: Span,
569 ) -> Option<TokenStream> {
570 match self {
571 Persistence::None => {
572 let sg_ident = subgraph_id.as_ident(span);
573 Some(quote_spanned!(span=> Subgraph(#sg_ident)))
574 }
575 Persistence::Loop => {
576 let loop_ident = loop_id
577 .expect("`Persistence::Loop` outside of a loop context.")
578 .as_ident(span);
579 Some(quote_spanned!(span=> Loop(#loop_ident)))
580 }
581 Persistence::Tick => Some(quote_spanned!(span=> Tick)),
582 Persistence::Static => None,
583 Persistence::Mutable => None,
584 }
585 }
586
587 pub fn to_str_lowercase(self) -> &'static str {
589 match self {
590 Persistence::None => "none",
591 Persistence::Tick => "tick",
592 Persistence::Loop => "loop",
593 Persistence::Static => "static",
594 Persistence::Mutable => "mutable",
595 }
596 }
597}
598
599fn make_missing_runtime_msg(op_name: &str) -> Literal {
601 Literal::string(&format!(
602 "`{}()` must be used within a Tokio runtime. For example, use `#[dfir_rs::main]` on your main method.",
603 op_name
604 ))
605}
606
607#[derive(Clone, Copy, Debug, PartialEq, Eq, PartialOrd, Ord, Hash, DocumentedVariants)]
609pub enum OperatorCategory {
610 Map,
612 Filter,
614 Flatten,
616 Fold,
618 KeyedFold,
620 LatticeFold,
622 Persistence,
624 MultiIn,
626 MultiOut,
628 Source,
630 Sink,
632 Control,
634 CompilerFusionOperator,
636 Windowing,
638 Unwindowing,
640}
641impl OperatorCategory {
642 pub fn name(self) -> &'static str {
644 self.get_variant_docs().split_once(":").unwrap().0
645 }
646 pub fn description(self) -> &'static str {
648 self.get_variant_docs().split_once(":").unwrap().1
649 }
650}
651
/// Optional classification of an operator, stored in [`OperatorConstraints::flo_type`].
///
/// The derived ordering follows declaration order:
/// `Source < Windowing < Unwindowing < NextIteration`.
// NOTE(review): the per-variant notes are inferred from the variant names and the matching
// `OperatorCategory` variants — confirm.
#[derive(Clone, Copy, PartialOrd, Ord, PartialEq, Eq, Debug)]
pub enum FloType {
    /// Presumably: a source operator.
    Source,
    /// Presumably: an operator that windows data entering a loop.
    Windowing,
    /// Presumably: an operator that un-windows data leaving a loop.
    Unwindowing,
    /// Presumably: an operator that passes data to the next loop iteration.
    NextIteration,
}