csp_login.rs 6.3 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202
//! Example usage of the Chat Server Protocol (CSP) state machine: performs a real handshake with
//! the chat server and exits immediately after a successful login.
  3. #![expect(unused_crate_dependencies, reason = "Example triggered false positive")]
  4. use std::io;
  5. use anyhow::bail;
  6. use clap::Parser;
  7. use libthreema::{
  8. cli::{FullIdentityConfig, FullIdentityConfigOptions},
  9. csp::{CspProtocol, CspProtocolContext, CspStateUpdate, frame::OutgoingFrame},
  10. https::cli::https_client_builder,
  11. utils::logging::init_stderr_logging,
  12. };
  13. use tokio::{
  14. io::{AsyncReadExt as _, AsyncWriteExt as _},
  15. net::TcpStream,
  16. };
  17. use tracing::{Level, debug, error, info, trace, warn};
// Command line arguments of this example.
//
// Plain `//` comments are used on purpose: clap would pick up `///` doc comments as `--help` text.
// NOTE(review): the type name says "ping pong" although this example only performs a login.
#[derive(Parser)]
#[command()]
struct CspPingPongCommand {
    // Identity and configuration options, flattened into this command's own arguments
    #[command(flatten)]
    config: FullIdentityConfigOptions,
}
/// The Chat Server Protocol (CSP) connection handler.
///
/// Owns the TCP connection to the chat server and feeds the bytes received on it into the
/// [`CspProtocol`] state machine.
struct CspProtocolRunner {
    /// The TCP stream connected to the chat server
    stream: TcpStream,
    /// An instance of the [`CspProtocol`] state machine
    protocol: CspProtocol,
}
  31. impl CspProtocolRunner {
  32. /// Initiate a CSP protocol connection and hand out the initial `client_hello` message
  33. #[tracing::instrument(skip_all)]
  34. async fn new(
  35. server_address: Vec<(String, u16)>,
  36. context: CspProtocolContext,
  37. ) -> anyhow::Result<(Self, OutgoingFrame)> {
  38. // Connect via TCP
  39. debug!(?server_address, "Establishing TCP connection to chat server",);
  40. let tcp_stream = TcpStream::connect(
  41. server_address
  42. .first()
  43. .expect("CSP config should have at least one address"),
  44. )
  45. .await?;
  46. // Create the protocol
  47. let (csp_protocol, client_hello) = CspProtocol::new(context);
  48. Ok((
  49. Self {
  50. stream: tcp_stream,
  51. protocol: csp_protocol,
  52. },
  53. client_hello,
  54. ))
  55. }
  56. /// Do the handshake with the chat server by exchanging the following messages:
  57. ///
  58. /// ```txt
  59. /// C -- client-hello -> S
  60. /// C <- server-hello -- S
  61. /// C ---- login ---- -> S
  62. /// C <-- login-ack ---- S
  63. /// ```
  64. #[tracing::instrument(skip_all)]
  65. async fn run_handshake_flow(&mut self, client_hello: OutgoingFrame) -> anyhow::Result<()> {
  66. // Send the client hello
  67. debug!(length = client_hello.0.len(), "Sending client hello");
  68. self.send(&client_hello.0).await?;
  69. // Handshake by polling the CSP state
  70. for iteration in 1_usize.. {
  71. trace!("Iteration #{iteration}");
  72. // Receive required bytes and add them
  73. let bytes = self.receive_required().await?;
  74. self.protocol.add_chunks(&[&bytes])?;
  75. // Handle instruction
  76. let Some(instruction) = self.protocol.poll()? else {
  77. continue;
  78. };
  79. // We do not expect an incoming payload at this stage
  80. if let Some(incoming_payload) = instruction.incoming_payload {
  81. let message = "Unexpected incoming payload during handshake";
  82. error!(?incoming_payload, message);
  83. bail!(message)
  84. }
  85. // Send any outgoing frame
  86. if let Some(frame) = instruction.outgoing_frame {
  87. self.send(&frame.0).await?;
  88. }
  89. // Check if we've completed the handshake
  90. if let Some(CspStateUpdate::PostHandshake { queued_messages }) = instruction.state_update {
  91. info!(queued_messages, "Handshake complete");
  92. break;
  93. }
  94. }
  95. Ok(())
  96. }
  97. /// Shut down the TCP connection
  98. #[tracing::instrument(skip_all)]
  99. async fn shutdown(&mut self) -> anyhow::Result<()> {
  100. info!("Shutting down TCP connection");
  101. Ok(self.stream.shutdown().await?)
  102. }
  103. /// Send bytes to the server over the TCP connection
  104. #[tracing::instrument(skip_all, fields(bytes_length = bytes.len()))]
  105. async fn send(&mut self, bytes: &[u8]) -> anyhow::Result<()> {
  106. trace!(length = bytes.len(), "Sending bytes");
  107. self.stream.write_all(bytes).await?;
  108. Ok(())
  109. }
  110. #[tracing::instrument(skip_all)]
  111. async fn receive_required(&mut self) -> anyhow::Result<Vec<u8>> {
  112. // Get the minimum amount of bytes we'll need to receive
  113. let length = self.protocol.next_required_length()?;
  114. let mut buffer = vec![0; length];
  115. trace!(?length, "Reading bytes");
  116. // If there is nothing to read, return immediately
  117. if length == 0 {
  118. return Ok(buffer);
  119. }
  120. // Read the exact number of bytes required
  121. let _ = self.stream.read_exact(&mut buffer).await?;
  122. // Read more if available
  123. match self.stream.try_read_buf(&mut buffer) {
  124. Ok(0) => {
  125. // Remote shut down our reading end. But we still need to process the previously
  126. // read bytes.
  127. warn!("TCP reading end closed");
  128. },
  129. Ok(length) => {
  130. trace!(length, "Got additional bytes");
  131. },
  132. Err(error) if error.kind() == io::ErrorKind::WouldBlock => {
  133. trace!("No additional bytes available");
  134. },
  135. Err(error) => {
  136. return Err(error.into());
  137. },
  138. }
  139. debug!(length = buffer.len(), "Received bytes");
  140. Ok(buffer)
  141. }
  142. }
  143. #[tokio::main]
  144. async fn main() -> anyhow::Result<()> {
  145. // Configure logging
  146. init_stderr_logging(Level::TRACE);
  147. // Create HTTP client
  148. let http_client = https_client_builder().build()?;
  149. // Parse arguments for command
  150. let arguments = CspPingPongCommand::parse();
  151. let config = FullIdentityConfig::from_options(&http_client, arguments.config).await?;
  152. // Create CSP protocol and establish a connection
  153. let (mut csp_runner, client_hello) = CspProtocolRunner::new(
  154. config
  155. .minimal
  156. .common
  157. .config
  158. .chat_server_address
  159. .addresses(config.csp_server_group),
  160. config.csp_context().expect("Configuration should be valid"),
  161. )
  162. .await?;
  163. // Run the handshake flow
  164. csp_runner.run_handshake_flow(client_hello).await?;
  165. // Shut down
  166. csp_runner.shutdown().await?;
  167. Ok(())
  168. }
  169. #[test]
  170. fn verify_cli() {
  171. use clap::CommandFactory;
  172. CspPingPongCommand::command().debug_assert();
  173. }