csp_ping_pong.rs 12 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366
  1. //! Example for usage of the Chat Server Protocol state machine, doing a real handshake with the
  2. //! chat server and an exemplary payload flow loop.
  3. #![expect(unused_crate_dependencies, reason = "Example triggered false positive")]
  4. #![expect(
  5. clippy::integer_division_remainder_used,
  6. reason = "Some internal of tokio::select triggers this"
  7. )]
  8. use core::time::Duration;
  9. use std::io;
  10. use anyhow::bail;
  11. use clap::Parser;
  12. use libthreema::{
  13. cli::{FullIdentityConfig, FullIdentityConfigOptions},
  14. csp::{
  15. CspProtocol, CspProtocolContext, CspStateUpdate,
  16. payload::{EchoPayload, IncomingPayload, OutgoingFrame, OutgoingPayload},
  17. },
  18. https::cli::https_client_builder,
  19. utils::logging::init_stderr_logging,
  20. };
  21. use tokio::{
  22. io::{AsyncReadExt as _, AsyncWriteExt as _},
  23. net::TcpStream,
  24. signal,
  25. sync::mpsc,
  26. time::{self, Instant},
  27. };
  28. use tracing::{Level, debug, error, info, trace, warn};
  29. #[derive(Parser)]
  30. #[command()]
  31. struct CspPingPongCommand {
  32. #[command(flatten)]
  33. config: FullIdentityConfigOptions,
  34. }
  35. /// Payload queues for the main process
  36. struct PayloadQueuesForCspPingPong {
  37. incoming: mpsc::Receiver<IncomingPayload>,
  38. outgoing: mpsc::Sender<OutgoingPayload>,
  39. }
  40. /// Payload queues for the protocol flow runner
  41. struct PayloadQueuesForCsp {
  42. incoming: mpsc::Sender<IncomingPayload>,
  43. outgoing: mpsc::Receiver<OutgoingPayload>,
  44. }
  45. struct CspProtocolRunner {
  46. /// The TCP stream
  47. stream: TcpStream,
  48. /// An instance of the [`CspProtocol`] state machine
  49. protocol: CspProtocol,
  50. }
  51. impl CspProtocolRunner {
  52. /// Initiate a CSP protocol connection and hand out the initial `client_hello` message
  53. #[tracing::instrument(skip_all)]
  54. async fn new(
  55. server_address: Vec<(String, u16)>,
  56. context: CspProtocolContext,
  57. ) -> anyhow::Result<(Self, OutgoingFrame)> {
  58. // Connect via TCP
  59. debug!(?server_address, "Establishing TCP connection to chat server",);
  60. let tcp_stream = TcpStream::connect(
  61. server_address
  62. .first()
  63. .expect("CSP config should have at least one address"),
  64. )
  65. .await?;
  66. // Create the protocol
  67. let (csp_protocol, client_hello) = CspProtocol::new(context);
  68. Ok((
  69. Self {
  70. stream: tcp_stream,
  71. protocol: csp_protocol,
  72. },
  73. client_hello,
  74. ))
  75. }
  76. /// Do the handshake with the chat server by exchanging the following messages:
  77. ///
  78. /// ```txt
  79. /// C -- client-hello -> S
  80. /// C <- server-hello -- S
  81. /// C ---- login ---- -> S
  82. /// C <-- login-ack ---- S
  83. /// ```
  84. #[tracing::instrument(skip_all)]
  85. async fn run_handshake_flow(&mut self, client_hello: OutgoingFrame) -> anyhow::Result<()> {
  86. // Send the client hello
  87. debug!(length = client_hello.0.len(), "Sending client hello");
  88. self.send(&client_hello.0).await?;
  89. // Handshake by polling the CSP state
  90. for iteration in 1_usize.. {
  91. trace!("Iteration #{iteration}");
  92. // Receive required bytes and add them
  93. let bytes = self.receive_required().await?;
  94. self.protocol.add_chunks(&[&bytes])?;
  95. // Handle instruction
  96. let Some(instruction) = self.protocol.poll()? else {
  97. continue;
  98. };
  99. // We do not expect an incoming payload at this stage
  100. if let Some(incoming_payload) = instruction.incoming_payload {
  101. let message = "Unexpected incoming payload during handshake";
  102. error!(?incoming_payload, message);
  103. bail!(message)
  104. }
  105. // Send any outgoing frame
  106. if let Some(frame) = instruction.outgoing_frame {
  107. self.send(&frame.0).await?;
  108. }
  109. // Check if we've completed the handshake
  110. if let Some(CspStateUpdate::PostHandshake(login_ack_data)) = instruction.state_update {
  111. info!(?login_ack_data, "Handshake complete");
  112. break;
  113. }
  114. }
  115. Ok(())
  116. }
  117. /// Run the payload exchange flow until stopped.
  118. #[tracing::instrument(skip_all)]
  119. async fn run_payload_flow(&mut self, mut queues: PayloadQueuesForCsp) -> anyhow::Result<()> {
  120. let mut read_buffer = [0_u8; 8192];
  121. for iteration in 1_usize.. {
  122. trace!("Payload flow iteration #{iteration}");
  123. // Poll for any pending instruction
  124. let mut instruction = self.protocol.poll()?;
  125. if instruction.is_none() {
  126. // No pending instruction left, wait for more input
  127. instruction = tokio::select! {
  128. // Forward any incoming chunks from the TCP stream
  129. _ = self.stream.readable() => {
  130. let length = self.try_receive(&mut read_buffer)?;
  131. // Add chunks (poll in the next iteration)
  132. self.protocol
  133. .add_chunks(&[read_buffer.get(..length)
  134. .expect("Amount of read bytes should be available")])?;
  135. None
  136. },
  137. // Forward any outgoing payloads
  138. Some(outgoing_payload) = queues.outgoing.recv() => {
  139. debug!(?outgoing_payload, "Sending payload");
  140. Some(self.protocol.create_payload(&outgoing_payload)?)
  141. }
  142. }
  143. }
  144. let Some(instruction) = instruction else {
  145. continue;
  146. };
  147. // We do not expect any state updates at this stage
  148. if let Some(state_update) = instruction.state_update {
  149. let message = "Unexpected state update after handshake";
  150. error!(?state_update, message);
  151. bail!(message)
  152. }
  153. // Forward any incoming payload
  154. if let Some(incoming_payload) = instruction.incoming_payload {
  155. debug!(?incoming_payload, "Received payload");
  156. queues.incoming.send(incoming_payload).await?;
  157. }
  158. // Send any outgoing frame
  159. if let Some(frame) = instruction.outgoing_frame {
  160. self.send(&frame.0).await?;
  161. }
  162. }
  163. Ok(())
  164. }
  165. /// Shut down the TCP connection
  166. #[tracing::instrument(skip_all)]
  167. async fn shutdown(&mut self) -> anyhow::Result<()> {
  168. info!("Shutting down TCP connection");
  169. Ok(self.stream.shutdown().await?)
  170. }
  171. /// Send bytes to the server over the TCP connection
  172. #[tracing::instrument(skip_all, fields(bytes_length = bytes.len()))]
  173. async fn send(&mut self, bytes: &[u8]) -> anyhow::Result<()> {
  174. trace!(length = bytes.len(), "Sending bytes");
  175. self.stream.write_all(bytes).await?;
  176. Ok(())
  177. }
  178. #[tracing::instrument(skip_all)]
  179. async fn receive_required(&mut self) -> anyhow::Result<Vec<u8>> {
  180. // Get the minimum amount of bytes we'll need to receive
  181. let length = self.protocol.next_required_length()?;
  182. let mut buffer = vec![0; length];
  183. trace!(?length, "Reading bytes");
  184. // If there is nothing to read, return immediately
  185. if length == 0 {
  186. return Ok(buffer);
  187. }
  188. // Read the exact number of bytes required
  189. let _ = self.stream.read_exact(&mut buffer).await?;
  190. // Read more if available
  191. match self.stream.try_read_buf(&mut buffer) {
  192. Ok(0) => {
  193. // Remote shut down our reading end. But we still need to process the previously
  194. // read bytes.
  195. warn!("TCP reading end closed");
  196. },
  197. Ok(length) => {
  198. trace!(length, "Got additional bytes");
  199. },
  200. Err(error) if error.kind() == io::ErrorKind::WouldBlock => {
  201. trace!("No additional bytes available");
  202. },
  203. Err(error) => {
  204. return Err(error.into());
  205. },
  206. }
  207. debug!(length = buffer.len(), "Received bytes");
  208. Ok(buffer)
  209. }
  210. #[tracing::instrument(skip_all)]
  211. fn try_receive(&mut self, buffer: &mut [u8]) -> anyhow::Result<usize> {
  212. match self.stream.try_read(buffer) {
  213. Ok(0) => {
  214. // Remote shut down our reading end gracefully.
  215. //
  216. // IMPORTANT: An implementation needs to ensure that it stops gracefully by processing any
  217. // remaining payloads prior to stopping the protocol. This example implementation ensures this
  218. // by handling all pending instructions prior to polling for more data. The only case we bail
  219. // is therefore when our instruction queue is already dry.
  220. bail!("TCP reading end closed")
  221. },
  222. Ok(length) => {
  223. debug!(length, "Received bytes");
  224. Ok(length)
  225. },
  226. Err(error) if error.kind() == io::ErrorKind::WouldBlock => {
  227. trace!("No bytes to receive");
  228. Ok(0)
  229. },
  230. Err(error) => Err(error.into()),
  231. }
  232. }
  233. }
  234. #[tracing::instrument(skip_all)]
  235. async fn run_ping_pong_flow(mut queues: PayloadQueuesForCspPingPong) -> anyhow::Result<()> {
  236. // Create the echo timer that will trigger an outgoing payload every 10s
  237. let mut echo_timer = time::interval_at(
  238. Instant::now()
  239. .checked_add(Duration::from_secs(10))
  240. .expect("Oops, apocalypse in 10s"),
  241. Duration::from_secs(10),
  242. );
  243. // Enter ping-pong flow loop
  244. loop {
  245. tokio::select! {
  246. // Send echo-request when the timer fires
  247. _ = echo_timer.tick() => {
  248. let echo_request = OutgoingPayload::EchoRequest(
  249. EchoPayload("ping".as_bytes().to_owned()));
  250. info!(?echo_request, "Sending echo request");
  251. queues.outgoing.send(echo_request).await?;
  252. },
  253. // Process incoming payload
  254. incoming_payload = queues.incoming.recv() => {
  255. if let Some(incoming_payload) = incoming_payload {
  256. info!(?incoming_payload, "Received payload");
  257. } else {
  258. break
  259. }
  260. }
  261. };
  262. }
  263. Ok(())
  264. }
  265. #[tokio::main]
  266. async fn main() -> anyhow::Result<()> {
  267. // Configure logging
  268. init_stderr_logging(Level::TRACE);
  269. // Create HTTP client
  270. let http_client = https_client_builder().build()?;
  271. // Parse command
  272. let arguments = CspPingPongCommand::parse();
  273. let config = FullIdentityConfig::from_options(&http_client, arguments.config).await?;
  274. // Create payload queues
  275. let (csp_ping_pong_queues, csp_queues) = {
  276. let incoming_payload = mpsc::channel(4);
  277. let outgoing_payload = mpsc::channel(4);
  278. (
  279. PayloadQueuesForCspPingPong {
  280. incoming: incoming_payload.1,
  281. outgoing: outgoing_payload.0,
  282. },
  283. PayloadQueuesForCsp {
  284. incoming: incoming_payload.0,
  285. outgoing: outgoing_payload.1,
  286. },
  287. )
  288. };
  289. // Create CSP protocol and establish a connection
  290. let (mut csp_runner, client_hello) = CspProtocolRunner::new(
  291. config
  292. .minimal
  293. .common
  294. .config
  295. .chat_server_address
  296. .addresses(config.csp_server_group),
  297. config
  298. .csp_context_init()
  299. .try_into()
  300. .expect("Configuration should be valid"),
  301. )
  302. .await?;
  303. // Run the handshake flow
  304. csp_runner.run_handshake_flow(client_hello).await?;
  305. // Run the protocols
  306. tokio::select! {
  307. _ = csp_runner.run_payload_flow(csp_queues) => {},
  308. _ = run_ping_pong_flow(csp_ping_pong_queues) => {},
  309. _ = signal::ctrl_c() => {},
  310. };
  311. // Shut down
  312. csp_runner.shutdown().await?;
  313. Ok(())
  314. }
  315. #[test]
  316. fn verify_cli() {
  317. use clap::CommandFactory;
  318. CspPingPongCommand::command().debug_assert();
  319. }