Skip to main content

hydro_lang/deploy/
deploy_runtime_containerized.rs

1#![allow(
2    unused,
3    reason = "unused in trybuild but the __staged version is needed"
4)]
5#![allow(missing_docs, reason = "used internally")]
6
7use std::collections::HashMap;
8use std::future::Future;
9use std::net::SocketAddr;
10use std::ops::{Deref, DerefMut};
11use std::pin::Pin;
12use std::sync::Arc;
13use std::task::{Context, Poll};
14use std::time::Duration;
15
16use bytes::BytesMut;
17use futures::{FutureExt, Sink, SinkExt, Stream, StreamExt};
18use proc_macro2::Span;
19use sinktools::demux_map_lazy::LazyDemuxSink;
20use sinktools::lazy::{LazySink, LazySource};
21use sinktools::lazy_sink_source::LazySinkSource;
22use stageleft::runtime_support::{
23    FreeVariableWithContext, FreeVariableWithContextWithProps, QuoteTokens,
24};
25use stageleft::{QuotedWithContext, q};
26use tokio::net::tcp::{OwnedReadHalf, OwnedWriteHalf};
27use tokio::net::{TcpListener, TcpStream};
28use tokio_util::codec::{FramedRead, FramedWrite, LengthDelimitedCodec};
29use tracing::{debug, instrument, warn};
30
31use crate::location::dynamic::LocationId;
32use crate::location::member_id::TaglessMemberId;
33use crate::location::{LocationKey, MemberId, MembershipEvent};
34
/// The single well-known TCP port that every node's [`ChannelMux`] listens on.
///
/// Senders dial `{host}:CHANNEL_MUX_PORT`; the handshake then selects the
/// logical channel, so one port serves every channel on a node.
pub const CHANNEL_MUX_PORT: u16 = 10_000;
37
/// Magic constant embedded in every [`ChannelMagic`] header.
///
/// Its big-endian bytes spell the ASCII text `HYDRO_CH`
/// (`0x4859_4452_4f5f_4348`).
pub const CHANNEL_MAGIC: u64 = u64::from_be_bytes(*b"HYDRO_CH");
40
41/// Magic header sent as the very first frame of every channel handshake.
42///
43/// This is a fixed value that never changes across versions, used to confirm
44/// both sides are speaking the same protocol family before anything else.
45#[derive(serde::Serialize, serde::Deserialize, Debug, Clone)]
46pub struct ChannelMagic {
47    pub magic: u64,
48}
49
/// Version of the channel handshake protocol spoken by this build.
///
/// It is sent inside [`ChannelProtocolVersion`]. The accept loop drops any
/// connection that reports a different value.
pub const CHANNEL_PROTOCOL_VERSION: u64 = 1;
52
53/// Protocol version sent as the second frame, after [`ChannelMagic`].
54///
55/// Incremented when the handshake format changes. The receiver checks this
56/// to decide how to deserialize the subsequent [`ChannelHandshake`] frame.
57#[derive(serde::Serialize, serde::Deserialize, Debug, Clone)]
58pub struct ChannelProtocolVersion {
59    pub version: u64,
60}
61
62/// Handshake message sent by the connecting side to identify the channel.
63///
64/// The receiver reads the third frame (after [`ChannelMagic`] and
65/// [`ChannelProtocolVersion`]) to know which logical channel the connection
66/// belongs to, and optionally which cluster member is connecting.
67/// cluster member is connecting.
68#[derive(serde::Serialize, serde::Deserialize, Debug, Clone)]
69pub struct ChannelHandshake {
70    /// The logical channel name for this connection.
71    pub channel_name: String,
72    /// If the sender is a cluster member, this is its identifier
73    /// (container name for Docker, task ID for ECS, etc.).
74    /// `None` for process-to-process connections.
75    pub sender_id: Option<String>,
76}
77
78/// A dispatched channel connection: optional sender ID and the read stream.
79type MuxConnection = (
80    Option<String>,
81    FramedRead<OwnedReadHalf, LengthDelimitedCodec>,
82);
83
84/// A shared accept loop that listens on a single port and dispatches
85/// incoming connections to the right consumer based on the channel name
86/// sent in the handshake.
87///
88/// Each node creates one of these at startup. Individual channels register
89/// themselves and receive their connection via a mpsc channel.
90pub struct ChannelMux {
91    /// Map from channel name to a sender that delivers accepted connections.
92    channels: std::sync::Mutex<HashMap<String, tokio::sync::mpsc::UnboundedSender<MuxConnection>>>,
93}
94
95impl Default for ChannelMux {
96    fn default() -> Self {
97        Self::new()
98    }
99}
100
101impl ChannelMux {
102    pub fn new() -> Self {
103        Self {
104            channels: std::sync::Mutex::new(HashMap::new()),
105        }
106    }
107
108    pub fn register(
109        &self,
110        channel_name: String,
111    ) -> tokio::sync::mpsc::UnboundedReceiver<MuxConnection> {
112        let (tx, rx) = tokio::sync::mpsc::unbounded_channel();
113        let mut channels = self.channels.lock().unwrap();
114        channels.insert(channel_name, tx);
115        rx
116    }
117
118    pub async fn run_accept_loop(self: Arc<Self>, listener: TcpListener) {
119        loop {
120            let (stream, peer) = match listener.accept().await {
121                Ok(v) => v,
122                Err(e) => {
123                    warn!(name: "accept_error", error = %e);
124                    continue;
125                }
126            };
127            debug!(name: "mux_accepting", ?peer);
128
129            let mux = self.clone();
130            tokio::spawn(async move {
131                let (rx, _tx) = stream.into_split();
132                let mut source = FramedRead::new(rx, LengthDelimitedCodec::new());
133
134                let Some(Ok(magic_frame)) = source.next().await else {
135                    warn!(name: "magic_failed", ?peer, "no magic frame");
136                    return;
137                };
138
139                let magic: ChannelMagic = match bincode::deserialize(&magic_frame) {
140                    Ok(m) => m,
141                    Err(e) => {
142                        warn!(name: "magic_deserialize_failed", ?peer, error = %e);
143                        return;
144                    }
145                };
146
147                if magic.magic != CHANNEL_MAGIC {
148                    warn!(name: "bad_magic", ?peer, magic = magic.magic, expected = CHANNEL_MAGIC);
149                    return;
150                }
151
152                let Some(Ok(version_frame)) = source.next().await else {
153                    warn!(name: "version_failed", ?peer, "no version frame");
154                    return;
155                };
156
157                let version: ChannelProtocolVersion = match bincode::deserialize(&version_frame) {
158                    Ok(v) => v,
159                    Err(e) => {
160                        warn!(name: "version_deserialize_failed", ?peer, error = %e);
161                        return;
162                    }
163                };
164
165                if version.version != CHANNEL_PROTOCOL_VERSION {
166                    warn!(name: "version_mismatch", ?peer, version = version.version, expected = CHANNEL_PROTOCOL_VERSION);
167                    return;
168                }
169
170                let Some(Ok(handshake_frame)) = source.next().await else {
171                    warn!(name: "handshake_failed", ?peer, "no handshake frame");
172                    return;
173                };
174
175                let handshake: ChannelHandshake = match bincode::deserialize(&handshake_frame) {
176                    Ok(h) => h,
177                    Err(e) => {
178                        warn!(name: "handshake_deserialize_failed", ?peer, error = %e);
179                        return;
180                    }
181                };
182
183                debug!(name: "handshake_received", ?peer, ?handshake);
184
185                let channels = mux.channels.lock().unwrap();
186                if let Some(tx_chan) = channels.get(&handshake.channel_name) {
187                    let _ = tx_chan.send((handshake.sender_id, source));
188                } else {
189                    warn!(
190                        name: "unknown_channel",
191                        channel_name = %handshake.channel_name,
192                        ?peer,
193                        "no registered consumer for channel"
194                    );
195                }
196            });
197        }
198    }
199}
200
201/// Get or initialize the global ChannelMux for this process.
202///
203/// The first call creates the TcpListener and spawns the accept loop.
204/// Subsequent calls return the same `Arc<ChannelMux>`.
205pub fn get_or_init_channel_mux() -> Arc<ChannelMux> {
206    use std::sync::OnceLock;
207    static MUX: OnceLock<Arc<ChannelMux>> = OnceLock::new();
208
209    MUX.get_or_init(|| {
210        let mux = Arc::new(ChannelMux::new());
211        let mux_clone = mux.clone();
212
213        // Spawn the accept loop in a background task.
214        // We use tokio::spawn which requires a runtime to be active.
215        tokio::spawn(async move {
216            let bind_addr = format!("0.0.0.0:{}", CHANNEL_MUX_PORT);
217            debug!(name: "mux_listening", %bind_addr);
218            let listener = TcpListener::bind(&bind_addr)
219                .await
220                .expect("failed to bind channel mux listener");
221            mux_clone.run_accept_loop(listener).await;
222        });
223
224        mux
225    })
226    .clone()
227}
228
229/// Sends a [`ChannelMagic`], then a [`ChannelProtocolVersion`], then a
230/// [`ChannelHandshake`] as three separate frames over the given sink.
231pub async fn send_handshake(
232    sink: &mut FramedWrite<TcpStream, LengthDelimitedCodec>,
233    channel_name: &str,
234    sender_id: Option<&str>,
235) -> Result<(), std::io::Error> {
236    let magic = ChannelMagic {
237        magic: CHANNEL_MAGIC,
238    };
239    sink.send(bytes::Bytes::from(bincode::serialize(&magic).unwrap()))
240        .await?;
241
242    let version = ChannelProtocolVersion {
243        version: CHANNEL_PROTOCOL_VERSION,
244    };
245    sink.send(bytes::Bytes::from(bincode::serialize(&version).unwrap()))
246        .await?;
247
248    let handshake = ChannelHandshake {
249        channel_name: channel_name.to_owned(),
250        sender_id: sender_id.map(|s| s.to_owned()),
251    };
252    sink.send(bytes::Bytes::from(bincode::serialize(&handshake).unwrap()))
253        .await?;
254    Ok(())
255}
256
257pub fn deploy_containerized_o2o(target: &str, channel_name: &str) -> (syn::Expr, syn::Expr) {
258    (
259        q!(LazySink::<_, _, _, bytes::Bytes>::new(move || Box::pin(
260            async move {
261                let channel_name = channel_name;
262                let target = format!("{}:{}", target, self::CHANNEL_MUX_PORT);
263                debug!(name: "connecting", %target, %channel_name);
264
265                let stream = TcpStream::connect(&target).await?;
266                let mut sink = FramedWrite::new(stream, LengthDelimitedCodec::new());
267
268                self::send_handshake(&mut sink, channel_name, None).await?;
269
270                Result::<_, std::io::Error>::Ok(sink)
271            }
272        )))
273        .splice_untyped_ctx(&()),
274        q!(LazySource::new(move || Box::pin(async move {
275            let channel_name = channel_name;
276            let mux = self::get_or_init_channel_mux();
277            let mut rx = mux.register(channel_name.to_owned());
278
279            let (_sender_id, source) = rx.recv().await.ok_or_else(|| {
280                std::io::Error::new(std::io::ErrorKind::ConnectionReset, "channel mux closed")
281            })?;
282
283            debug!(name: "o2o_channel_connected", %channel_name);
284
285            Result::<_, std::io::Error>::Ok(source)
286        })))
287        .splice_untyped_ctx(&()),
288    )
289}
290
291pub fn deploy_containerized_o2m(channel_name: &str) -> (syn::Expr, syn::Expr) {
292    (
293        q!(sinktools::demux_map_lazy::<_, _, _, _>(
294            move |key: &TaglessMemberId| {
295                let key = key.clone();
296                let channel_name = channel_name.to_owned();
297
298                LazySink::<_, _, _, bytes::Bytes>::new(move || {
299                    Box::pin(async move {
300                        let target =
301                            format!("{}:{}", key.get_container_name(), self::CHANNEL_MUX_PORT);
302                        debug!(name: "connecting", %target, channel_name = %channel_name);
303
304                        let stream = TcpStream::connect(&target).await?;
305                        let mut sink = FramedWrite::new(stream, LengthDelimitedCodec::new());
306
307                        self::send_handshake(&mut sink, &channel_name, None).await?;
308
309                        Result::<_, std::io::Error>::Ok(sink)
310                    })
311                })
312            }
313        ))
314        .splice_untyped_ctx(&()),
315        q!(LazySource::new(move || Box::pin(async move {
316            let channel_name = channel_name;
317            let mux = self::get_or_init_channel_mux();
318            let mut rx = mux.register(channel_name.to_owned());
319
320            let (_sender_id, source) = rx.recv().await.ok_or_else(|| {
321                std::io::Error::new(std::io::ErrorKind::ConnectionReset, "channel mux closed")
322            })?;
323
324            debug!(name: "o2m_channel_connected", %channel_name);
325
326            Result::<_, std::io::Error>::Ok(source)
327        })))
328        .splice_untyped_ctx(&()),
329    )
330}
331
332pub fn deploy_containerized_m2o(target_host: &str, channel_name: &str) -> (syn::Expr, syn::Expr) {
333    (
334        q!(LazySink::<_, _, _, bytes::Bytes>::new(move || {
335            Box::pin(async move {
336                let channel_name = channel_name;
337                let target = format!("{}:{}", target_host, self::CHANNEL_MUX_PORT);
338                debug!(name: "connecting", %target, %channel_name);
339
340                let stream = TcpStream::connect(&target).await?;
341                let mut sink = FramedWrite::new(stream, LengthDelimitedCodec::new());
342
343                let container_name = std::env::var("CONTAINER_NAME").unwrap();
344                self::send_handshake(&mut sink, channel_name, Some(&container_name)).await?;
345
346                Result::<_, std::io::Error>::Ok(sink)
347            })
348        }))
349        .splice_untyped_ctx(&()),
350        q!(LazySource::new(move || Box::pin(async move {
351            let channel_name = channel_name;
352            let mux = self::get_or_init_channel_mux();
353            let mut rx = mux.register(channel_name.to_owned());
354
355            Result::<_, std::io::Error>::Ok(
356                futures::stream::unfold(rx, |mut rx| {
357                    Box::pin(async move {
358                        let (sender_id, source) = rx.recv().await?;
359                        let from = sender_id.expect("m2o sender must provide container name");
360
361                        debug!(name: "m2o_channel_connected", %from);
362
363                        Some((
364                            source.map(move |v| {
365                                v.map(|v| (TaglessMemberId::from_container_name(from.clone()), v))
366                            }),
367                            rx,
368                        ))
369                    })
370                })
371                .flatten_unordered(None),
372            )
373        })))
374        .splice_untyped_ctx(&()),
375    )
376}
377
378pub fn deploy_containerized_m2m(channel_name: &str) -> (syn::Expr, syn::Expr) {
379    (
380        q!(sinktools::demux_map_lazy::<_, _, _, _>(
381            move |key: &TaglessMemberId| {
382                let key = key.clone();
383                let channel_name = channel_name.to_owned();
384
385                LazySink::<_, _, _, bytes::Bytes>::new(move || {
386                    Box::pin(async move {
387                        let target =
388                            format!("{}:{}", key.get_container_name(), self::CHANNEL_MUX_PORT);
389                        debug!(name: "connecting", %target, channel_name = %channel_name);
390
391                        let stream = TcpStream::connect(&target).await?;
392                        let mut sink = FramedWrite::new(stream, LengthDelimitedCodec::new());
393
394                        let container_name = std::env::var("CONTAINER_NAME").unwrap();
395                        self::send_handshake(&mut sink, &channel_name, Some(&container_name))
396                            .await?;
397
398                        Result::<_, std::io::Error>::Ok(sink)
399                    })
400                })
401            }
402        ))
403        .splice_untyped_ctx(&()),
404        q!(LazySource::new(move || Box::pin(async move {
405            let channel_name = channel_name;
406            let mux = self::get_or_init_channel_mux();
407            let mut rx = mux.register(channel_name.to_owned());
408
409            Result::<_, std::io::Error>::Ok(
410                futures::stream::unfold(rx, |mut rx| {
411                    Box::pin(async move {
412                        let (sender_id, source) = rx.recv().await?;
413                        let from = sender_id.expect("m2m sender must provide container name");
414
415                        debug!(name: "m2m_channel_connected", %from);
416
417                        Some((
418                            source.map(move |v| {
419                                v.map(|v| (TaglessMemberId::from_container_name(from.clone()), v))
420                            }),
421                            rx,
422                        ))
423                    })
424                })
425                .flatten_unordered(None),
426            )
427        })))
428        .splice_untyped_ctx(&()),
429    )
430}
431
432pub struct SocketIdent {
433    pub socket_ident: syn::Ident,
434}
435
436impl<Ctx> FreeVariableWithContextWithProps<Ctx, ()> for SocketIdent {
437    type O = TcpListener;
438
439    fn to_tokens(self, _ctx: &Ctx) -> (QuoteTokens, ())
440    where
441        Self: Sized,
442    {
443        let ident = self.socket_ident;
444
445        (
446            QuoteTokens {
447                prelude: None,
448                expr: Some(quote::quote! { #ident }),
449            },
450            (),
451        )
452    }
453}
454
455pub fn deploy_containerized_external_sink_source_ident(socket_ident: syn::Ident) -> syn::Expr {
456    let socket_ident = SocketIdent { socket_ident };
457
458    q!(LazySinkSource::<
459        _,
460        FramedRead<OwnedReadHalf, LengthDelimitedCodec>,
461        FramedWrite<OwnedWriteHalf, LengthDelimitedCodec>,
462        bytes::Bytes,
463        std::io::Error,
464    >::new(async move {
465        let (stream, peer) = socket_ident.accept().await?;
466        debug!(name: "external accepting", ?peer);
467        let (rx, tx) = stream.into_split();
468
469        let fr = FramedRead::new(rx, LengthDelimitedCodec::new());
470        let fw = FramedWrite::new(tx, LengthDelimitedCodec::new());
471
472        Result::<_, std::io::Error>::Ok((fr, fw))
473    },))
474    .splice_untyped_ctx(&())
475}
476
477pub fn cluster_ids<'a>() -> impl QuotedWithContext<'a, &'a [TaglessMemberId], ()> + Clone {
478    // unimplemented!(); // this is unused.
479
480    // This is a dummy piece of code, since clusters are dynamic when containerized.
481    q!(Box::leak(Box::new([TaglessMemberId::from_container_name(
482        "INVALID CONTAINER NAME cluster_ids"
483    )]))
484    .as_slice())
485}
486
/// Quoted expression for this cluster member's own ID.
///
/// The ID is built from the `CONTAINER_NAME` environment variable at runtime.
/// The generated code panics if that variable is unset.
#[cfg(feature = "docker_runtime")]
pub fn cluster_self_id<'a>() -> impl QuotedWithContext<'a, TaglessMemberId, ()> + Clone + 'a {
    q!(TaglessMemberId::from_container_name(std::env::var("CONTAINER_NAME").unwrap()))
}
493
/// Quoted, boxed stream of membership changes for the cluster at
/// `location_id`.
///
/// The stream is backed by Docker container events via
/// `docker_membership_stream`. The generated code reads `DEPLOYMENT_INSTANCE`
/// from the environment and panics if it is unset.
#[cfg(feature = "docker_runtime")]
pub fn cluster_membership_stream<'a>(
    location_id: &LocationId,
) -> impl QuotedWithContext<'a, Box<dyn Stream<Item = (TaglessMemberId, MembershipEvent)> + Unpin>, ()>
{
    // Must stay named `key`: `q!` captures it by that name.
    let key = location_id.key();

    q!(Box::new(self::docker_membership_stream(
        std::env::var("DEPLOYMENT_INSTANCE").unwrap(),
        key
    )) as Box<dyn Stream<Item = (TaglessMemberId, MembershipEvent)> + Unpin>)
}
509
510#[cfg(feature = "docker_runtime")]
511// There's a risk of race conditions here since all the containers will be starting up at the same time.
512// So we need to start listening for events and the take a snapshot of currently running containers, since they may have already started up before we started listening to events.
513// Then we need to turn that into a usable stream for the consumer in this current hydro program. The way you do that is by emitting from the snapshot first, and then start emitting from the stream. Keep a hash set around to track whether a container is up or down.
514#[instrument(skip_all, fields(%deployment_instance, %location_key))]
515fn docker_membership_stream(
516    deployment_instance: String,
517    location_key: LocationKey,
518) -> impl Stream<Item = (TaglessMemberId, MembershipEvent)> + Unpin {
519    use std::collections::HashSet;
520    use std::sync::{Arc, Mutex};
521
522    use bollard::Docker;
523    use bollard::query_parameters::{EventsOptions, ListContainersOptions};
524    use tokio::sync::mpsc;
525
526    let docker = Docker::connect_with_local_defaults()
527        .unwrap()
528        .with_timeout(Duration::from_secs(1));
529
530    let (event_tx, event_rx) = mpsc::unbounded_channel::<(String, MembershipEvent)>();
531
532    // 1. Start event subscription in a spawned task
533    let events_docker = docker.clone();
534    let events_deployment_instance = deployment_instance.clone();
535    tokio::spawn(async move {
536        let mut filters = HashMap::new();
537        filters.insert("type".to_owned(), vec!["container".to_owned()]);
538        filters.insert(
539            "event".to_owned(),
540            vec!["start".to_owned(), "die".to_owned()],
541        );
542        let event_options = Some(EventsOptions {
543            filters: Some(filters),
544            ..Default::default()
545        });
546
547        let mut events = events_docker.events(event_options);
548        while let Some(event) = events.next().await {
549            if let Some((name, membership_event)) = event.ok().and_then(|e| {
550                let name = e
551                    .actor
552                    .as_ref()
553                    .and_then(|a| a.attributes.as_ref())
554                    .and_then(|attrs| attrs.get("name"))
555                    .map(|s| &**s)?;
556
557                if name.contains(format!("{events_deployment_instance}-{location_key}").as_str()) {
558                    match e.action.as_deref() {
559                        Some("start") => Some((name.to_owned(), MembershipEvent::Joined)),
560                        Some("die") => Some((name.to_owned(), MembershipEvent::Left)),
561                        _ => None,
562                    }
563                } else {
564                    None
565                }
566            }) && event_tx.send((name, membership_event)).is_err()
567            {
568                break;
569            }
570        }
571    });
572
573    // Shared state for deduplication across snapshot and events phases
574    let seen_joined = Arc::new(Mutex::new(HashSet::<String>::new()));
575    let seen_joined_snapshot = seen_joined.clone();
576    let seen_joined_events = seen_joined;
577
578    // 2. Snapshot stream - fetch current containers and emit Joined events
579    let snapshot_stream = futures::stream::once(async move {
580        let mut filters = HashMap::new();
581        filters.insert(
582            "name".to_owned(),
583            vec![format!("{deployment_instance}-{location_key}")],
584        );
585        let options = Some(ListContainersOptions {
586            filters: Some(filters),
587            ..Default::default()
588        });
589
590        docker
591            .list_containers(options)
592            .await
593            .unwrap_or_default()
594            .iter()
595            .filter_map(|c| c.names.as_deref())
596            .filter_map(|names| names.first())
597            .map(|name| name.trim_start_matches('/'))
598            .filter(|&name| seen_joined_snapshot.lock().unwrap().insert(name.to_owned()))
599            .map(|name| (name.to_owned(), MembershipEvent::Joined))
600            .collect::<Vec<_>>()
601    })
602    .flat_map(futures::stream::iter);
603
604    // 3. Events stream - process live events with deduplication
605    let events_stream = tokio_stream::StreamExt::filter_map(
606        tokio_stream::wrappers::UnboundedReceiverStream::new(event_rx),
607        move |(name, event)| {
608            let mut seen = seen_joined_events.lock().unwrap();
609            match event {
610                MembershipEvent::Joined => {
611                    if seen.insert(name.to_owned()) {
612                        Some((name, MembershipEvent::Joined))
613                    } else {
614                        None
615                    }
616                }
617                MembershipEvent::Left => seen.take(&name).map(|name| (name, MembershipEvent::Left)),
618            }
619        },
620    );
621
622    // 4. Chain snapshot then events
623    Box::pin(
624        snapshot_stream
625            .chain(events_stream)
626            .map(|(k, v)| (TaglessMemberId::from_container_name(k), v))
627            .inspect(|(member_id, event)| debug!(name: "membership_event", ?member_id, ?event)),
628    )
629}