1#![allow(
2 unused,
3 reason = "unused in trybuild but the __staged version is needed"
4)]
5#![allow(missing_docs, reason = "used internally")]
6
7use std::collections::HashMap;
8use std::future::Future;
9use std::net::SocketAddr;
10use std::ops::{Deref, DerefMut};
11use std::pin::Pin;
12use std::sync::Arc;
13use std::task::{Context, Poll};
14use std::time::Duration;
15
16use bytes::BytesMut;
17use futures::{FutureExt, Sink, SinkExt, Stream, StreamExt};
18use proc_macro2::Span;
19use sinktools::demux_map_lazy::LazyDemuxSink;
20use sinktools::lazy::{LazySink, LazySource};
21use sinktools::lazy_sink_source::LazySinkSource;
22use stageleft::runtime_support::{
23 FreeVariableWithContext, FreeVariableWithContextWithProps, QuoteTokens,
24};
25use stageleft::{QuotedWithContext, q};
26use tokio::net::tcp::{OwnedReadHalf, OwnedWriteHalf};
27use tokio::net::{TcpListener, TcpStream};
28use tokio_util::codec::{FramedRead, FramedWrite, LengthDelimitedCodec};
29use tracing::{debug, instrument, warn};
30
31use crate::location::dynamic::LocationId;
32use crate::location::member_id::TaglessMemberId;
33use crate::location::{LocationKey, MemberId, MembershipEvent};
34
/// TCP port on which every container's channel mux (see `get_or_init_channel_mux`) listens
/// for inbound channel connections, and to which channel senders connect.
pub const CHANNEL_MUX_PORT: u16 = 10_000;
37
/// Magic number carried by the first frame of every mux connection; connections whose first
/// frame does not carry it are dropped. Read as big-endian bytes it spells ASCII `HYDRO_CH`.
pub const CHANNEL_MAGIC: u64 = u64::from_be_bytes(*b"HYDRO_CH");
40
41#[derive(serde::Serialize, serde::Deserialize, Debug, Clone)]
46pub struct ChannelMagic {
47 pub magic: u64,
48}
49
/// Protocol version sent in the second preamble frame. The receiving mux drops connections
/// whose advertised version differs from this value.
pub const CHANNEL_PROTOCOL_VERSION: u64 = 1u64;
52
53#[derive(serde::Serialize, serde::Deserialize, Debug, Clone)]
58pub struct ChannelProtocolVersion {
59 pub version: u64,
60}
61
62#[derive(serde::Serialize, serde::Deserialize, Debug, Clone)]
69pub struct ChannelHandshake {
70 pub channel_name: String,
72 pub sender_id: Option<String>,
76}
77
78type MuxConnection = (
80 Option<String>,
81 FramedRead<OwnedReadHalf, LengthDelimitedCodec>,
82);
83
84pub struct ChannelMux {
91 channels: std::sync::Mutex<HashMap<String, tokio::sync::mpsc::UnboundedSender<MuxConnection>>>,
93}
94
95impl Default for ChannelMux {
96 fn default() -> Self {
97 Self::new()
98 }
99}
100
101impl ChannelMux {
102 pub fn new() -> Self {
103 Self {
104 channels: std::sync::Mutex::new(HashMap::new()),
105 }
106 }
107
108 pub fn register(
109 &self,
110 channel_name: String,
111 ) -> tokio::sync::mpsc::UnboundedReceiver<MuxConnection> {
112 let (tx, rx) = tokio::sync::mpsc::unbounded_channel();
113 let mut channels = self.channels.lock().unwrap();
114 channels.insert(channel_name, tx);
115 rx
116 }
117
118 pub async fn run_accept_loop(self: Arc<Self>, listener: TcpListener) {
119 loop {
120 let (stream, peer) = match listener.accept().await {
121 Ok(v) => v,
122 Err(e) => {
123 warn!(name: "accept_error", error = %e);
124 continue;
125 }
126 };
127 debug!(name: "mux_accepting", ?peer);
128
129 let mux = self.clone();
130 tokio::spawn(async move {
131 let (rx, _tx) = stream.into_split();
132 let mut source = FramedRead::new(rx, LengthDelimitedCodec::new());
133
134 let Some(Ok(magic_frame)) = source.next().await else {
135 warn!(name: "magic_failed", ?peer, "no magic frame");
136 return;
137 };
138
139 let magic: ChannelMagic = match bincode::deserialize(&magic_frame) {
140 Ok(m) => m,
141 Err(e) => {
142 warn!(name: "magic_deserialize_failed", ?peer, error = %e);
143 return;
144 }
145 };
146
147 if magic.magic != CHANNEL_MAGIC {
148 warn!(name: "bad_magic", ?peer, magic = magic.magic, expected = CHANNEL_MAGIC);
149 return;
150 }
151
152 let Some(Ok(version_frame)) = source.next().await else {
153 warn!(name: "version_failed", ?peer, "no version frame");
154 return;
155 };
156
157 let version: ChannelProtocolVersion = match bincode::deserialize(&version_frame) {
158 Ok(v) => v,
159 Err(e) => {
160 warn!(name: "version_deserialize_failed", ?peer, error = %e);
161 return;
162 }
163 };
164
165 if version.version != CHANNEL_PROTOCOL_VERSION {
166 warn!(name: "version_mismatch", ?peer, version = version.version, expected = CHANNEL_PROTOCOL_VERSION);
167 return;
168 }
169
170 let Some(Ok(handshake_frame)) = source.next().await else {
171 warn!(name: "handshake_failed", ?peer, "no handshake frame");
172 return;
173 };
174
175 let handshake: ChannelHandshake = match bincode::deserialize(&handshake_frame) {
176 Ok(h) => h,
177 Err(e) => {
178 warn!(name: "handshake_deserialize_failed", ?peer, error = %e);
179 return;
180 }
181 };
182
183 debug!(name: "handshake_received", ?peer, ?handshake);
184
185 let channels = mux.channels.lock().unwrap();
186 if let Some(tx_chan) = channels.get(&handshake.channel_name) {
187 let _ = tx_chan.send((handshake.sender_id, source));
188 } else {
189 warn!(
190 name: "unknown_channel",
191 channel_name = %handshake.channel_name,
192 ?peer,
193 "no registered consumer for channel"
194 );
195 }
196 });
197 }
198 }
199}
200
201pub fn get_or_init_channel_mux() -> Arc<ChannelMux> {
206 use std::sync::OnceLock;
207 static MUX: OnceLock<Arc<ChannelMux>> = OnceLock::new();
208
209 MUX.get_or_init(|| {
210 let mux = Arc::new(ChannelMux::new());
211 let mux_clone = mux.clone();
212
213 tokio::spawn(async move {
216 let bind_addr = format!("0.0.0.0:{}", CHANNEL_MUX_PORT);
217 debug!(name: "mux_listening", %bind_addr);
218 let listener = TcpListener::bind(&bind_addr)
219 .await
220 .expect("failed to bind channel mux listener");
221 mux_clone.run_accept_loop(listener).await;
222 });
223
224 mux
225 })
226 .clone()
227}
228
229pub async fn send_handshake(
232 sink: &mut FramedWrite<TcpStream, LengthDelimitedCodec>,
233 channel_name: &str,
234 sender_id: Option<&str>,
235) -> Result<(), std::io::Error> {
236 let magic = ChannelMagic {
237 magic: CHANNEL_MAGIC,
238 };
239 sink.send(bytes::Bytes::from(bincode::serialize(&magic).unwrap()))
240 .await?;
241
242 let version = ChannelProtocolVersion {
243 version: CHANNEL_PROTOCOL_VERSION,
244 };
245 sink.send(bytes::Bytes::from(bincode::serialize(&version).unwrap()))
246 .await?;
247
248 let handshake = ChannelHandshake {
249 channel_name: channel_name.to_owned(),
250 sender_id: sender_id.map(|s| s.to_owned()),
251 };
252 sink.send(bytes::Bytes::from(bincode::serialize(&handshake).unwrap()))
253 .await?;
254 Ok(())
255}
256
/// Builds the `(sink, source)` expressions for a one-to-one channel between two containers.
///
/// - The sink is a `LazySink` whose future connects to `target` on [`CHANNEL_MUX_PORT`],
///   wraps the stream in a length-delimited `FramedWrite`, and sends the preamble with no
///   sender id.
/// - The source is a `LazySource` whose future registers `channel_name` with the
///   process-wide [`ChannelMux`], starting it if needed. It resolves to the read half of the
///   first connection delivered for that channel.
///
/// Both expressions are produced with `q!` and spliced with an empty context.
pub fn deploy_containerized_o2o(target: &str, channel_name: &str) -> (syn::Expr, syn::Expr) {
    (
        q!(LazySink::<_, _, _, bytes::Bytes>::new(move || Box::pin(
            async move {
                // NOTE(review): presumably rebinds the captured free variable for stageleft's
                // `q!`; confirm before removing.
                let channel_name = channel_name;
                let target = format!("{}:{}", target, self::CHANNEL_MUX_PORT);
                debug!(name: "connecting", %target, %channel_name);

                let stream = TcpStream::connect(&target).await?;
                let mut sink = FramedWrite::new(stream, LengthDelimitedCodec::new());

                // One-to-one senders do not identify themselves (sender_id = None).
                self::send_handshake(&mut sink, channel_name, None).await?;

                Result::<_, std::io::Error>::Ok(sink)
            }
        )))
        .splice_untyped_ctx(&()),
        q!(LazySource::new(move || Box::pin(async move {
            let channel_name = channel_name;
            let mux = self::get_or_init_channel_mux();
            let mut rx = mux.register(channel_name.to_owned());

            // Wait for exactly one connection; `recv` returns None only once every sender
            // for this registration has been dropped.
            let (_sender_id, source) = rx.recv().await.ok_or_else(|| {
                std::io::Error::new(std::io::ErrorKind::ConnectionReset, "channel mux closed")
            })?;

            debug!(name: "o2o_channel_connected", %channel_name);

            Result::<_, std::io::Error>::Ok(source)
        })))
        .splice_untyped_ctx(&()),
    )
}
290
/// Builds the `(sink, source)` expressions for a one-to-many channel (one sender container,
/// many cluster-member receivers).
///
/// - The sink is a `demux_map_lazy` keyed by [`TaglessMemberId`]. For each key it builds a
///   `LazySink` that connects to that member's container name on [`CHANNEL_MUX_PORT`] and
///   sends the preamble with no sender id.
/// - The source registers `channel_name` with the process-wide [`ChannelMux`] and resolves
///   to the read half of the first connection delivered for that channel.
pub fn deploy_containerized_o2m(channel_name: &str) -> (syn::Expr, syn::Expr) {
    (
        q!(sinktools::demux_map_lazy::<_, _, _, _>(
            move |key: &TaglessMemberId| {
                let key = key.clone();
                let channel_name = channel_name.to_owned();

                // One lazily-connecting sink per destination member.
                LazySink::<_, _, _, bytes::Bytes>::new(move || {
                    Box::pin(async move {
                        let target =
                            format!("{}:{}", key.get_container_name(), self::CHANNEL_MUX_PORT);
                        debug!(name: "connecting", %target, channel_name = %channel_name);

                        let stream = TcpStream::connect(&target).await?;
                        let mut sink = FramedWrite::new(stream, LengthDelimitedCodec::new());

                        self::send_handshake(&mut sink, &channel_name, None).await?;

                        Result::<_, std::io::Error>::Ok(sink)
                    })
                })
            }
        ))
        .splice_untyped_ctx(&()),
        q!(LazySource::new(move || Box::pin(async move {
            // NOTE(review): presumably rebinds the captured free variable for stageleft's
            // `q!`; confirm before removing.
            let channel_name = channel_name;
            let mux = self::get_or_init_channel_mux();
            let mut rx = mux.register(channel_name.to_owned());

            // Each receiver has a single upstream sender, so only the first connection is
            // taken.
            let (_sender_id, source) = rx.recv().await.ok_or_else(|| {
                std::io::Error::new(std::io::ErrorKind::ConnectionReset, "channel mux closed")
            })?;

            debug!(name: "o2m_channel_connected", %channel_name);

            Result::<_, std::io::Error>::Ok(source)
        })))
        .splice_untyped_ctx(&()),
    )
}
331
/// Builds the `(sink, source)` expressions for a many-to-one channel (many cluster-member
/// senders, one receiver container).
///
/// - The sink is a `LazySink` that connects to `target_host` on [`CHANNEL_MUX_PORT`] and
///   sends the preamble with this container's `CONTAINER_NAME` as the sender id. It panics
///   if that environment variable is unset or not valid Unicode.
/// - The source registers `channel_name` with the process-wide [`ChannelMux`] and yields a
///   single stream merging every connection delivered for the channel. Each frame is tagged
///   with the sender's [`TaglessMemberId`], derived from its handshake sender id. A
///   connection whose handshake carries no sender id causes a panic.
pub fn deploy_containerized_m2o(target_host: &str, channel_name: &str) -> (syn::Expr, syn::Expr) {
    (
        q!(LazySink::<_, _, _, bytes::Bytes>::new(move || {
            Box::pin(async move {
                let channel_name = channel_name;
                let target = format!("{}:{}", target_host, self::CHANNEL_MUX_PORT);
                debug!(name: "connecting", %target, %channel_name);

                let stream = TcpStream::connect(&target).await?;
                let mut sink = FramedWrite::new(stream, LengthDelimitedCodec::new());

                // The receiver uses this to attribute incoming frames to this sender.
                let container_name = std::env::var("CONTAINER_NAME").unwrap();
                self::send_handshake(&mut sink, channel_name, Some(&container_name)).await?;

                Result::<_, std::io::Error>::Ok(sink)
            })
        }))
        .splice_untyped_ctx(&()),
        q!(LazySource::new(move || Box::pin(async move {
            let channel_name = channel_name;
            let mux = self::get_or_init_channel_mux();
            let mut rx = mux.register(channel_name.to_owned());

            Result::<_, std::io::Error>::Ok(
                // Turn the receiver of connections into a stream of per-connection streams;
                // it ends when `recv` returns None.
                futures::stream::unfold(rx, |mut rx| {
                    Box::pin(async move {
                        let (sender_id, source) = rx.recv().await?;
                        let from = sender_id.expect("m2o sender must provide container name");

                        debug!(name: "m2o_channel_connected", %from);

                        Some((
                            source.map(move |v| {
                                v.map(|v| (TaglessMemberId::from_container_name(from.clone()), v))
                            }),
                            rx,
                        ))
                    })
                })
                // Merge all connections; `None` places no limit on how many are polled
                // concurrently.
                .flatten_unordered(None),
            )
        })))
        .splice_untyped_ctx(&()),
    )
}
377
/// Builds the `(sink, source)` expressions for a many-to-many channel between cluster
/// members.
///
/// - The sink is a `demux_map_lazy` keyed by [`TaglessMemberId`]. For each key it builds a
///   `LazySink` that connects to that member's container name on [`CHANNEL_MUX_PORT`] and
///   sends the preamble with this container's `CONTAINER_NAME` as the sender id. It panics
///   if that variable is unset or not valid Unicode.
/// - The source registers `channel_name` with the process-wide [`ChannelMux`] and merges
///   every connection delivered for the channel into one stream of `(sender, frame)` items.
///   A connection whose handshake carries no sender id causes a panic.
pub fn deploy_containerized_m2m(channel_name: &str) -> (syn::Expr, syn::Expr) {
    (
        q!(sinktools::demux_map_lazy::<_, _, _, _>(
            move |key: &TaglessMemberId| {
                let key = key.clone();
                let channel_name = channel_name.to_owned();

                // One lazily-connecting sink per destination member.
                LazySink::<_, _, _, bytes::Bytes>::new(move || {
                    Box::pin(async move {
                        let target =
                            format!("{}:{}", key.get_container_name(), self::CHANNEL_MUX_PORT);
                        debug!(name: "connecting", %target, channel_name = %channel_name);

                        let stream = TcpStream::connect(&target).await?;
                        let mut sink = FramedWrite::new(stream, LengthDelimitedCodec::new());

                        let container_name = std::env::var("CONTAINER_NAME").unwrap();
                        self::send_handshake(&mut sink, &channel_name, Some(&container_name))
                            .await?;

                        Result::<_, std::io::Error>::Ok(sink)
                    })
                })
            }
        ))
        .splice_untyped_ctx(&()),
        q!(LazySource::new(move || Box::pin(async move {
            let channel_name = channel_name;
            let mux = self::get_or_init_channel_mux();
            let mut rx = mux.register(channel_name.to_owned());

            Result::<_, std::io::Error>::Ok(
                // One inner stream per accepted connection; ends when `recv` returns None.
                futures::stream::unfold(rx, |mut rx| {
                    Box::pin(async move {
                        let (sender_id, source) = rx.recv().await?;
                        let from = sender_id.expect("m2m sender must provide container name");

                        debug!(name: "m2m_channel_connected", %from);

                        Some((
                            source.map(move |v| {
                                v.map(|v| (TaglessMemberId::from_container_name(from.clone()), v))
                            }),
                            rx,
                        ))
                    })
                })
                // Merge all connections with no concurrency limit.
                .flatten_unordered(None),
            )
        })))
        .splice_untyped_ctx(&()),
    )
}
431
432pub struct SocketIdent {
433 pub socket_ident: syn::Ident,
434}
435
436impl<Ctx> FreeVariableWithContextWithProps<Ctx, ()> for SocketIdent {
437 type O = TcpListener;
438
439 fn to_tokens(self, _ctx: &Ctx) -> (QuoteTokens, ())
440 where
441 Self: Sized,
442 {
443 let ident = self.socket_ident;
444
445 (
446 QuoteTokens {
447 prelude: None,
448 expr: Some(quote::quote! { #ident }),
449 },
450 (),
451 )
452 }
453}
454
/// Builds an expression for a bidirectional external connection served on an existing
/// listener.
///
/// `socket_ident` names a [`TcpListener`] that is in scope in the generated code (see
/// [`SocketIdent`]). The resulting `LazySinkSource` accepts one connection on that listener
/// and splits it into a length-delimited `FramedRead` / `FramedWrite` pair. Accept errors
/// surface as `std::io::Error`.
pub fn deploy_containerized_external_sink_source_ident(socket_ident: syn::Ident) -> syn::Expr {
    let socket_ident = SocketIdent { socket_ident };

    q!(LazySinkSource::<
        _,
        FramedRead<OwnedReadHalf, LengthDelimitedCodec>,
        FramedWrite<OwnedWriteHalf, LengthDelimitedCodec>,
        bytes::Bytes,
        std::io::Error,
    >::new(async move {
        // `socket_ident` is substituted with the listener identifier by `SocketIdent::to_tokens`.
        let (stream, peer) = socket_ident.accept().await?;
        debug!(name: "external accepting", ?peer);
        let (rx, tx) = stream.into_split();

        let fr = FramedRead::new(rx, LengthDelimitedCodec::new());
        let fw = FramedWrite::new(tx, LengthDelimitedCodec::new());

        Result::<_, std::io::Error>::Ok((fr, fw))
    },))
    .splice_untyped_ctx(&())
}
476
/// Quoted expression for the static list of cluster member ids.
///
/// In this containerized runtime it evaluates to a single placeholder id whose container
/// name is `"INVALID CONTAINER NAME cluster_ids"`.
/// NOTE(review): presumably this only satisfies the interface, and real membership comes
/// from `cluster_membership_stream`; confirm before relying on it. Every evaluation of the
/// quoted expression leaks one boxed array via `Box::leak`.
pub fn cluster_ids<'a>() -> impl QuotedWithContext<'a, &'a [TaglessMemberId], ()> + Clone {
    q!(Box::leak(Box::new([TaglessMemberId::from_container_name(
        "INVALID CONTAINER NAME cluster_ids"
    )]))
    .as_slice())
}
486
/// Quoted expression for this member's own id.
///
/// At runtime it builds the id from the `CONTAINER_NAME` environment variable and panics if
/// that variable is unset or not valid Unicode.
#[cfg(feature = "docker_runtime")]
pub fn cluster_self_id<'a>() -> impl QuotedWithContext<'a, TaglessMemberId, ()> + Clone + 'a {
    q!(TaglessMemberId::from_container_name(
        std::env::var("CONTAINER_NAME").unwrap()
    ))
}
493
/// Quoted expression producing a boxed stream of membership events for `location_id`.
///
/// The location key is computed now and captured into the quote. At runtime the stream
/// comes from `docker_membership_stream`, using the `DEPLOYMENT_INSTANCE` environment
/// variable; evaluation panics if that variable is unset or not valid Unicode.
#[cfg(feature = "docker_runtime")]
pub fn cluster_membership_stream<'a>(
    location_id: &LocationId,
) -> impl QuotedWithContext<'a, Box<dyn Stream<Item = (TaglessMemberId, MembershipEvent)> + Unpin>, ()>
{
    let key = location_id.key();

    q!(Box::new(self::docker_membership_stream(
        std::env::var("DEPLOYMENT_INSTANCE").unwrap(),
        key
    ))
        as Box<
            dyn Stream<Item = (TaglessMemberId, MembershipEvent)> + Unpin,
        >)
}
509
/// Streams membership changes for the containers of one cluster location, as observed
/// through the local Docker daemon.
///
/// Both phases match containers against the string `{deployment_instance}-{location_key}`:
/// 1. The stream first yields `Joined` for every container returned by a one-shot
///    `list_containers` query using that string as the Docker `name` filter.
/// 2. It then yields live events derived from Docker container `start` events (`Joined`)
///    and `die` events (`Left`), for containers whose name contains that string.
///
/// A shared `seen_joined` set removes duplicate `Joined` events and suppresses `Left` for
/// containers never reported as joined. Docker errors are logged rather than silently
/// discarded.
///
/// # Panics
/// Panics if a Docker client cannot be created from the local defaults. It must be called
/// within a Tokio runtime, because it spawns the event-watching task.
#[cfg(feature = "docker_runtime")]
#[instrument(skip_all, fields(%deployment_instance, %location_key))]
fn docker_membership_stream(
    deployment_instance: String,
    location_key: LocationKey,
) -> impl Stream<Item = (TaglessMemberId, MembershipEvent)> + Unpin {
    use std::collections::HashSet;
    use std::sync::{Arc, Mutex};

    use bollard::Docker;
    use bollard::query_parameters::{EventsOptions, ListContainersOptions};
    use tokio::sync::mpsc;

    // Identifies containers of this deployment/location. It is computed once and shared by
    // the live-event filter and the initial container listing.
    let name_filter = format!("{deployment_instance}-{location_key}");

    let docker = Docker::connect_with_local_defaults()
        .unwrap()
        .with_timeout(Duration::from_secs(1));

    let (event_tx, event_rx) = mpsc::unbounded_channel::<(String, MembershipEvent)>();

    // The event watcher is spawned before the snapshot below is taken. Containers reported
    // by both are de-duplicated through `seen_joined`.
    let events_docker = docker.clone();
    let events_name_filter = name_filter.clone();
    tokio::spawn(async move {
        let mut filters = HashMap::new();
        filters.insert("type".to_owned(), vec!["container".to_owned()]);
        filters.insert(
            "event".to_owned(),
            vec!["start".to_owned(), "die".to_owned()],
        );
        let event_options = Some(EventsOptions {
            filters: Some(filters),
            ..Default::default()
        });

        let mut events = events_docker.events(event_options);
        while let Some(event) = events.next().await {
            let event = match event {
                Ok(event) => event,
                Err(e) => {
                    // Keep consuming; the Docker stream decides whether more items follow.
                    warn!(name: "docker_event_error", error = %e);
                    continue;
                }
            };

            let Some(name) = event
                .actor
                .as_ref()
                .and_then(|a| a.attributes.as_ref())
                .and_then(|attrs| attrs.get("name"))
            else {
                continue;
            };

            if !name.contains(events_name_filter.as_str()) {
                continue;
            }

            let membership_event = match event.action.as_deref() {
                Some("start") => MembershipEvent::Joined,
                Some("die") => MembershipEvent::Left,
                _ => continue,
            };

            // The membership stream has been dropped; stop watching Docker events.
            if event_tx.send((name.to_owned(), membership_event)).is_err() {
                break;
            }
        }
    });

    let seen_joined = Arc::new(Mutex::new(HashSet::<String>::new()));
    let seen_joined_snapshot = seen_joined.clone();
    let seen_joined_events = seen_joined;

    let snapshot_stream = futures::stream::once(async move {
        let mut filters = HashMap::new();
        filters.insert("name".to_owned(), vec![name_filter]);
        let options = Some(ListContainersOptions {
            filters: Some(filters),
            ..Default::default()
        });

        // On failure the snapshot is empty, and membership relies solely on live events.
        let containers = docker.list_containers(options).await.unwrap_or_else(|e| {
            warn!(name: "docker_list_containers_error", error = %e);
            Vec::new()
        });

        containers
            .iter()
            .filter_map(|c| c.names.as_deref())
            .filter_map(|names| names.first())
            // Docker reports container names with a leading '/'.
            .map(|name| name.trim_start_matches('/'))
            .filter(|&name| seen_joined_snapshot.lock().unwrap().insert(name.to_owned()))
            .map(|name| (name.to_owned(), MembershipEvent::Joined))
            .collect::<Vec<_>>()
    })
    .flat_map(futures::stream::iter);

    let events_stream = tokio_stream::StreamExt::filter_map(
        tokio_stream::wrappers::UnboundedReceiverStream::new(event_rx),
        move |(name, event)| {
            let mut seen = seen_joined_events.lock().unwrap();
            match event {
                // Only report a join the first time a container is seen.
                MembershipEvent::Joined => {
                    if seen.insert(name.clone()) {
                        Some((name, MembershipEvent::Joined))
                    } else {
                        None
                    }
                }
                // Only report a leave for containers previously reported as joined.
                MembershipEvent::Left => seen.take(&name).map(|name| (name, MembershipEvent::Left)),
            }
        },
    );

    Box::pin(
        snapshot_stream
            .chain(events_stream)
            .map(|(k, v)| (TaglessMemberId::from_container_name(k), v))
            .inspect(|(member_id, event)| debug!(name: "membership_event", ?member_id, ?event)),
    )
}