csp_login.rs 6.3 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204
  1. //! Example for usage of the Chat Server Protocol state machine, doing a real handshake with the
  2. //! chat server, exiting immediately after successful login.
  3. #![expect(unused_crate_dependencies, reason = "Example triggered false positive")]
  4. use std::io;
  5. use anyhow::bail;
  6. use clap::Parser;
  7. use libthreema::{
  8. cli::{FullIdentityConfig, FullIdentityConfigOptions},
  9. csp::{CspProtocol, CspProtocolContext, CspStateUpdate, payload::OutgoingFrame},
  10. https::cli::https_client_builder,
  11. utils::logging::init_stderr_logging,
  12. };
  13. use tokio::{
  14. io::{AsyncReadExt as _, AsyncWriteExt as _},
  15. net::TcpStream,
  16. };
  17. use tracing::{Level, debug, error, info, trace, warn};
  18. #[derive(Parser)]
  19. #[command()]
  20. struct CspPingPongCommand {
  21. #[command(flatten)]
  22. config: FullIdentityConfigOptions,
  23. }
  24. struct CspProtocolRunner {
  25. /// The TCP stream
  26. stream: TcpStream,
  27. /// An instance of the [`CspProtocol`] state machine
  28. protocol: CspProtocol,
  29. }
  30. impl CspProtocolRunner {
  31. /// Initiate a CSP protocol connection and hand out the initial `client_hello` message
  32. #[tracing::instrument(skip_all)]
  33. async fn new(
  34. server_address: Vec<(String, u16)>,
  35. context: CspProtocolContext,
  36. ) -> anyhow::Result<(Self, OutgoingFrame)> {
  37. // Connect via TCP
  38. debug!(?server_address, "Establishing TCP connection to chat server",);
  39. let tcp_stream = TcpStream::connect(
  40. server_address
  41. .first()
  42. .expect("CSP config should have at least one address"),
  43. )
  44. .await?;
  45. // Create the protocol
  46. let (csp_protocol, client_hello) = CspProtocol::new(context);
  47. Ok((
  48. Self {
  49. stream: tcp_stream,
  50. protocol: csp_protocol,
  51. },
  52. client_hello,
  53. ))
  54. }
  55. /// Do the handshake with the chat server by exchanging the following messages:
  56. ///
  57. /// ```txt
  58. /// C -- client-hello -> S
  59. /// C <- server-hello -- S
  60. /// C ---- login ---- -> S
  61. /// C <-- login-ack ---- S
  62. /// ```
  63. #[tracing::instrument(skip_all)]
  64. async fn run_handshake_flow(&mut self, client_hello: OutgoingFrame) -> anyhow::Result<()> {
  65. // Send the client hello
  66. debug!(length = client_hello.0.len(), "Sending client hello");
  67. self.send(&client_hello.0).await?;
  68. // Handshake by polling the CSP state
  69. for iteration in 1_usize.. {
  70. trace!("Iteration #{iteration}");
  71. // Receive required bytes and add them
  72. let bytes = self.receive_required().await?;
  73. self.protocol.add_chunks(&[&bytes])?;
  74. // Handle instruction
  75. let Some(instruction) = self.protocol.poll()? else {
  76. continue;
  77. };
  78. // We do not expect an incoming payload at this stage
  79. if let Some(incoming_payload) = instruction.incoming_payload {
  80. let message = "Unexpected incoming payload during handshake";
  81. error!(?incoming_payload, message);
  82. bail!(message)
  83. }
  84. // Send any outgoing frame
  85. if let Some(frame) = instruction.outgoing_frame {
  86. self.send(&frame.0).await?;
  87. }
  88. // Check if we've completed the handshake
  89. if let Some(CspStateUpdate::PostHandshake(login_ack_data)) = instruction.state_update {
  90. info!(?login_ack_data, "Handshake complete");
  91. break;
  92. }
  93. }
  94. Ok(())
  95. }
  96. /// Shut down the TCP connection
  97. #[tracing::instrument(skip_all)]
  98. async fn shutdown(&mut self) -> anyhow::Result<()> {
  99. info!("Shutting down TCP connection");
  100. Ok(self.stream.shutdown().await?)
  101. }
  102. /// Send bytes to the server over the TCP connection
  103. #[tracing::instrument(skip_all, fields(bytes_length = bytes.len()))]
  104. async fn send(&mut self, bytes: &[u8]) -> anyhow::Result<()> {
  105. trace!(length = bytes.len(), "Sending bytes");
  106. self.stream.write_all(bytes).await?;
  107. Ok(())
  108. }
  109. #[tracing::instrument(skip_all)]
  110. async fn receive_required(&mut self) -> anyhow::Result<Vec<u8>> {
  111. // Get the minimum amount of bytes we'll need to receive
  112. let length = self.protocol.next_required_length()?;
  113. let mut buffer = vec![0; length];
  114. trace!(?length, "Reading bytes");
  115. // If there is nothing to read, return immediately
  116. if length == 0 {
  117. return Ok(buffer);
  118. }
  119. // Read the exact number of bytes required
  120. let _ = self.stream.read_exact(&mut buffer).await?;
  121. // Read more if available
  122. match self.stream.try_read_buf(&mut buffer) {
  123. Ok(0) => {
  124. // Remote shut down our reading end. But we still need to process the previously
  125. // read bytes.
  126. warn!("TCP reading end closed");
  127. },
  128. Ok(length) => {
  129. trace!(length, "Got additional bytes");
  130. },
  131. Err(error) if error.kind() == io::ErrorKind::WouldBlock => {
  132. trace!("No additional bytes available");
  133. },
  134. Err(error) => {
  135. return Err(error.into());
  136. },
  137. }
  138. debug!(length = buffer.len(), "Received bytes");
  139. Ok(buffer)
  140. }
  141. }
  142. #[tokio::main]
  143. async fn main() -> anyhow::Result<()> {
  144. // Configure logging
  145. init_stderr_logging(Level::TRACE);
  146. // Create HTTP client
  147. let http_client = https_client_builder().build()?;
  148. // Parse arguments for command
  149. let arguments = CspPingPongCommand::parse();
  150. let config = FullIdentityConfig::from_options(&http_client, arguments.config).await?;
  151. // Create CSP protocol and establish a connection
  152. let (mut csp_runner, client_hello) = CspProtocolRunner::new(
  153. config
  154. .minimal
  155. .common
  156. .config
  157. .chat_server_address
  158. .addresses(config.csp_server_group),
  159. config
  160. .csp_context_init()
  161. .try_into()
  162. .expect("Configuration should be valid"),
  163. )
  164. .await?;
  165. // Run the handshake flow
  166. csp_runner.run_handshake_flow(client_hello).await?;
  167. // Shut down
  168. csp_runner.shutdown().await?;
  169. Ok(())
  170. }
  171. #[test]
  172. fn verify_cli() {
  173. use clap::CommandFactory;
  174. CspPingPongCommand::command().debug_assert();
  175. }