csp_ping_pong.rs 12 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365
//! Example usage of the Chat Server Protocol (CSP) state machine: performs a real handshake with
//! the chat server and then runs an exemplary payload flow loop.
  3. #![expect(unused_crate_dependencies, reason = "Example triggered false positive")]
  4. #![expect(
  5. clippy::integer_division_remainder_used,
  6. reason = "Some internal of tokio::select triggers this"
  7. )]
  8. use core::time::Duration;
  9. use std::io;
  10. use anyhow::bail;
  11. use clap::Parser;
  12. use libthreema::{
  13. cli::{FullIdentityConfig, FullIdentityConfigOptions},
  14. csp::{
  15. CspProtocol, CspProtocolContext, CspStateUpdate,
  16. frame::OutgoingFrame,
  17. payload::{EchoPayload, IncomingPayload, OutgoingPayload},
  18. },
  19. https::cli::https_client_builder,
  20. utils::logging::init_stderr_logging,
  21. };
  22. use tokio::{
  23. io::{AsyncReadExt as _, AsyncWriteExt as _},
  24. net::TcpStream,
  25. signal,
  26. sync::mpsc,
  27. time::{self, Instant},
  28. };
  29. use tracing::{Level, debug, error, info, trace, warn};
// Command line arguments of this example.
//
// Only consists of the identity configuration options which `main` turns into a
// `FullIdentityConfig` via `FullIdentityConfig::from_options`.
#[derive(Parser)]
#[command()]
struct CspPingPongCommand {
    // Identity configuration options, flattened into this command's own arguments
    #[command(flatten)]
    config: FullIdentityConfigOptions,
}
/// Payload queues for the main process, i.e. the ping-pong (application) side.
///
/// Counterpart of [`PayloadQueuesForCsp`]: both structs hold opposite ends of the same two
/// channels.
struct PayloadQueuesForCspPingPong {
    /// Payloads received from the chat server, forwarded by the protocol runner
    incoming: mpsc::Receiver<IncomingPayload>,
    /// Payloads that the protocol runner should encode and send to the chat server
    outgoing: mpsc::Sender<OutgoingPayload>,
}
/// Payload queues for the protocol flow runner, i.e. the chat server connection side.
///
/// Counterpart of [`PayloadQueuesForCspPingPong`]: both structs hold opposite ends of the same two
/// channels.
struct PayloadQueuesForCsp {
    /// Where payloads received from the chat server are forwarded to
    incoming: mpsc::Sender<IncomingPayload>,
    /// Payloads to be encoded by the protocol and sent to the chat server
    outgoing: mpsc::Receiver<OutgoingPayload>,
}
/// The Chat Server Protocol (CSP) connection handler.
///
/// Owns the TCP connection to the chat server and drives the [`CspProtocol`] state machine by
/// feeding it the received bytes and sending the frames it produces.
struct CspProtocolRunner {
    /// The TCP stream
    stream: TcpStream,
    /// An instance of the [`CspProtocol`] state machine
    protocol: CspProtocol,
}
  53. impl CspProtocolRunner {
  54. /// Initiate a CSP protocol connection and hand out the initial `client_hello` message
  55. #[tracing::instrument(skip_all)]
  56. async fn new(
  57. server_address: Vec<(String, u16)>,
  58. context: CspProtocolContext,
  59. ) -> anyhow::Result<(Self, OutgoingFrame)> {
  60. // Connect via TCP
  61. debug!(?server_address, "Establishing TCP connection to chat server",);
  62. let tcp_stream = TcpStream::connect(
  63. server_address
  64. .first()
  65. .expect("CSP config should have at least one address"),
  66. )
  67. .await?;
  68. // Create the protocol
  69. let (csp_protocol, client_hello) = CspProtocol::new(context);
  70. Ok((
  71. Self {
  72. stream: tcp_stream,
  73. protocol: csp_protocol,
  74. },
  75. client_hello,
  76. ))
  77. }
  78. /// Do the handshake with the chat server by exchanging the following messages:
  79. ///
  80. /// ```txt
  81. /// C -- client-hello -> S
  82. /// C <- server-hello -- S
  83. /// C ---- login ---- -> S
  84. /// C <-- login-ack ---- S
  85. /// ```
  86. #[tracing::instrument(skip_all)]
  87. async fn run_handshake_flow(&mut self, client_hello: OutgoingFrame) -> anyhow::Result<()> {
  88. // Send the client hello
  89. debug!(length = client_hello.0.len(), "Sending client hello");
  90. self.send(&client_hello.0).await?;
  91. // Handshake by polling the CSP state
  92. for iteration in 1_usize.. {
  93. trace!("Iteration #{iteration}");
  94. // Receive required bytes and add them
  95. let bytes = self.receive_required().await?;
  96. self.protocol.add_chunks(&[&bytes])?;
  97. // Handle instruction
  98. let Some(instruction) = self.protocol.poll()? else {
  99. continue;
  100. };
  101. // We do not expect an incoming payload at this stage
  102. if let Some(incoming_payload) = instruction.incoming_payload {
  103. let message = "Unexpected incoming payload during handshake";
  104. error!(?incoming_payload, message);
  105. bail!(message)
  106. }
  107. // Send any outgoing frame
  108. if let Some(frame) = instruction.outgoing_frame {
  109. self.send(&frame.0).await?;
  110. }
  111. // Check if we've completed the handshake
  112. if let Some(CspStateUpdate::PostHandshake { queued_messages }) = instruction.state_update {
  113. info!(queued_messages, "Handshake complete");
  114. break;
  115. }
  116. }
  117. Ok(())
  118. }
  119. /// Run the payload exchange flow until stopped.
  120. #[tracing::instrument(skip_all)]
  121. async fn run_payload_flow(&mut self, mut queues: PayloadQueuesForCsp) -> anyhow::Result<()> {
  122. let mut read_buffer = [0_u8; 8192];
  123. for iteration in 1_usize.. {
  124. trace!("Payload flow iteration #{iteration}");
  125. // Poll for any pending instruction
  126. let mut instruction = self.protocol.poll()?;
  127. if instruction.is_none() {
  128. // No pending instruction left, wait for more input
  129. instruction = tokio::select! {
  130. // Forward any incoming chunks from the TCP stream
  131. _ = self.stream.readable() => {
  132. let length = self.try_receive(&mut read_buffer)?;
  133. // Add chunks (poll in the next iteration)
  134. self.protocol
  135. .add_chunks(&[read_buffer.get(..length)
  136. .expect("Amount of read bytes should be available")])?;
  137. None
  138. }
  139. // Forward any outgoing payloads
  140. Some(outgoing_payload) = queues.outgoing.recv() => {
  141. debug!(?outgoing_payload, "Sending payload");
  142. Some(self.protocol.create_payload(&outgoing_payload)?)
  143. }
  144. }
  145. }
  146. let Some(instruction) = instruction else {
  147. continue;
  148. };
  149. // We do not expect any state updates at this stage
  150. if let Some(state_update) = instruction.state_update {
  151. let message = "Unexpected state update after handshake";
  152. error!(?state_update, message);
  153. bail!(message)
  154. }
  155. // Forward any incoming payload
  156. if let Some(incoming_payload) = instruction.incoming_payload {
  157. debug!(?incoming_payload, "Received payload");
  158. queues.incoming.send(incoming_payload).await?;
  159. }
  160. // Send any outgoing frame
  161. if let Some(frame) = instruction.outgoing_frame {
  162. self.send(&frame.0).await?;
  163. }
  164. }
  165. Ok(())
  166. }
  167. /// Shut down the TCP connection
  168. #[tracing::instrument(skip_all)]
  169. async fn shutdown(&mut self) -> anyhow::Result<()> {
  170. info!("Shutting down TCP connection");
  171. Ok(self.stream.shutdown().await?)
  172. }
  173. /// Send bytes to the server over the TCP connection
  174. #[tracing::instrument(skip_all, fields(bytes_length = bytes.len()))]
  175. async fn send(&mut self, bytes: &[u8]) -> anyhow::Result<()> {
  176. trace!(length = bytes.len(), "Sending bytes");
  177. self.stream.write_all(bytes).await?;
  178. Ok(())
  179. }
  180. #[tracing::instrument(skip_all)]
  181. async fn receive_required(&mut self) -> anyhow::Result<Vec<u8>> {
  182. // Get the minimum amount of bytes we'll need to receive
  183. let length = self.protocol.next_required_length()?;
  184. let mut buffer = vec![0; length];
  185. trace!(?length, "Reading bytes");
  186. // If there is nothing to read, return immediately
  187. if length == 0 {
  188. return Ok(buffer);
  189. }
  190. // Read the exact number of bytes required
  191. let _ = self.stream.read_exact(&mut buffer).await?;
  192. // Read more if available
  193. match self.stream.try_read_buf(&mut buffer) {
  194. Ok(0) => {
  195. // Remote shut down our reading end. But we still need to process the previously
  196. // read bytes.
  197. warn!("TCP reading end closed");
  198. },
  199. Ok(length) => {
  200. trace!(length, "Got additional bytes");
  201. },
  202. Err(error) if error.kind() == io::ErrorKind::WouldBlock => {
  203. trace!("No additional bytes available");
  204. },
  205. Err(error) => {
  206. return Err(error.into());
  207. },
  208. }
  209. debug!(length = buffer.len(), "Received bytes");
  210. Ok(buffer)
  211. }
  212. #[tracing::instrument(skip_all)]
  213. fn try_receive(&mut self, buffer: &mut [u8]) -> anyhow::Result<usize> {
  214. match self.stream.try_read(buffer) {
  215. Ok(0) => {
  216. // Remote shut down our reading end gracefully.
  217. //
  218. // IMPORTANT: An implementation needs to ensure that it stops gracefully by processing any
  219. // remaining payloads prior to stopping the protocol. This example implementation ensures this
  220. // by handling all pending instructions prior to polling for more data. The only case we bail
  221. // is therefore when our instruction queue is already dry.
  222. bail!("TCP reading end closed")
  223. },
  224. Ok(length) => {
  225. debug!(length, "Received bytes");
  226. Ok(length)
  227. },
  228. Err(error) if error.kind() == io::ErrorKind::WouldBlock => {
  229. trace!("No bytes to receive");
  230. Ok(0)
  231. },
  232. Err(error) => Err(error.into()),
  233. }
  234. }
  235. }
  236. #[tracing::instrument(skip_all)]
  237. async fn run_ping_pong_flow(mut queues: PayloadQueuesForCspPingPong) -> anyhow::Result<()> {
  238. // Create the echo timer that will trigger an outgoing payload every 10s
  239. let mut echo_timer = time::interval_at(
  240. Instant::now()
  241. .checked_add(Duration::from_secs(10))
  242. .expect("Oops, apocalypse in 10s"),
  243. Duration::from_secs(10),
  244. );
  245. // Enter application loop
  246. loop {
  247. tokio::select! {
  248. // Send echo-request when the timer fires
  249. _ = echo_timer.tick() => {
  250. let echo_request = OutgoingPayload::EchoRequest(
  251. EchoPayload("ping".as_bytes().to_owned()));
  252. info!(?echo_request, "Sending echo request");
  253. queues.outgoing.send(echo_request).await?;
  254. }
  255. // Process incoming payload (or stop signal)
  256. incoming_payload = queues.incoming.recv() => {
  257. if let Some(incoming_payload) = incoming_payload {
  258. info!(?incoming_payload, "Received payload");
  259. } else {
  260. break
  261. }
  262. }
  263. };
  264. }
  265. Ok(())
  266. }
  267. #[tokio::main]
  268. async fn main() -> anyhow::Result<()> {
  269. // Configure logging
  270. init_stderr_logging(Level::TRACE);
  271. // Create HTTP client
  272. let http_client = https_client_builder().build()?;
  273. // Parse command
  274. let arguments = CspPingPongCommand::parse();
  275. let config = FullIdentityConfig::from_options(&http_client, arguments.config).await?;
  276. // Create payload queues
  277. let (csp_ping_pong_queues, csp_queues) = {
  278. let incoming_payload = mpsc::channel(4);
  279. let outgoing_payload = mpsc::channel(4);
  280. (
  281. PayloadQueuesForCspPingPong {
  282. incoming: incoming_payload.1,
  283. outgoing: outgoing_payload.0,
  284. },
  285. PayloadQueuesForCsp {
  286. incoming: incoming_payload.0,
  287. outgoing: outgoing_payload.1,
  288. },
  289. )
  290. };
  291. // Create CSP protocol and establish a connection
  292. let (mut csp_runner, client_hello) = CspProtocolRunner::new(
  293. config
  294. .minimal
  295. .common
  296. .config
  297. .chat_server_address
  298. .addresses(config.csp_server_group),
  299. config.csp_context().expect("Configuration should be valid"),
  300. )
  301. .await?;
  302. // Run the handshake flow
  303. csp_runner.run_handshake_flow(client_hello).await?;
  304. // Run the protocols
  305. tokio::select! {
  306. _ = csp_runner.run_payload_flow(csp_queues) => {},
  307. _ = run_ping_pong_flow(csp_ping_pong_queues) => {},
  308. _ = signal::ctrl_c() => {},
  309. };
  310. // Shut down
  311. csp_runner.shutdown().await?;
  312. Ok(())
  313. }
  314. #[test]
  315. fn verify_cli() {
  316. use clap::CommandFactory;
  317. CspPingPongCommand::command().debug_assert();
  318. }