csp.rs 12 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376
  1. //! Example for usage of the Chat Server Protocol state machine, doing a real handshake with the
  2. //! chat server and an exemplary payload flow loop.
  3. #![expect(unused_crate_dependencies, reason = "Example triggered false positive")]
  4. #![expect(
  5. clippy::integer_division_remainder_used,
  6. reason = "Some internal of tokio::select triggers this"
  7. )]
  8. use core::{net::SocketAddr, time::Duration};
  9. use std::io;
  10. use anyhow::{Result, bail};
  11. use clap::Parser;
  12. use libthreema::{
  13. common::{PublicKey, RawClientKey, ThreemaId},
  14. csp::{
  15. Context, CspProtocol, CspStateUpdate,
  16. frame::OutgoingFrame,
  17. payload::{EchoPayload, IncomingPayload, OutgoingPayload},
  18. },
  19. utils::logging::init_stderr_logging,
  20. };
  21. use tokio::{
  22. io::{AsyncReadExt as _, AsyncWriteExt as _},
  23. net::TcpStream,
  24. signal,
  25. sync::mpsc,
  26. task,
  27. time::{self, Instant},
  28. };
  29. use tracing::{Level, debug, error, info, trace, warn};
  30. /// Fulfill a handshake with the chat server
  31. #[derive(Parser)]
  32. #[command()]
  33. struct Main {
  34. /// Address of the server, e.g., 1.2.3.4:80
  35. #[arg(long)]
  36. server_address: SocketAddr,
  37. /// The server's public key
  38. #[arg(
  39. long,
  40. required = true,
  41. num_args = 1..,
  42. value_delimiter = ',',
  43. value_parser = PublicKey::from_hex_cli
  44. )]
  45. permanent_server_key: Vec<PublicKey>,
  46. /// Threema ID
  47. #[arg(short, long, value_parser = ThreemaId::from_str_cli)]
  48. threema_id: ThreemaId,
  49. /// Client key (32 bytes base64 encoded)
  50. #[arg(short, long, value_parser = RawClientKey::from_hex_cli)]
  51. client_key: RawClientKey,
  52. }
  53. impl Main {
  54. /// Parse arguments to context and server address
  55. fn parse_context_server_address() -> (Context, SocketAddr) {
  56. let main = Main::parse();
  57. let context = Context::new(
  58. main.permanent_server_key,
  59. main.threema_id,
  60. main.client_key.into(),
  61. "libthreema;example;de/ch;testing".to_owned(),
  62. None,
  63. None,
  64. )
  65. .expect("permanent_server_key should not be empty");
  66. debug!(?context, "Starting protocol");
  67. (context, main.server_address)
  68. }
  69. }
  70. /// Payload queues for the main process
  71. struct PayloadQueuesForMain {
  72. incoming: mpsc::Receiver<IncomingPayload>,
  73. outgoing: mpsc::Sender<OutgoingPayload>,
  74. }
  75. /// Payload queues for the protocol flow runner
  76. struct PayloadQueuesForProtocol {
  77. incoming: mpsc::Sender<IncomingPayload>,
  78. outgoing: mpsc::Receiver<OutgoingPayload>,
  79. }
  80. /// The Client Server Protocol connection handler
  81. struct CspConnection {
  82. /// The TCP stream
  83. tcp_stream: TcpStream,
  84. /// An instance of the [`CspProtocol`] state machine
  85. protocol: CspProtocol,
  86. }
  87. impl CspConnection {
  88. /// Initiate a CSP protocol connection and hand out the initial `client_hello` message
  89. pub(crate) async fn new(server_address: SocketAddr, context: Context) -> Result<(Self, OutgoingFrame)> {
  90. // Connect via TCP
  91. debug!(?server_address, "Establishing TCP connection to chat server",);
  92. let tcp_stream = TcpStream::connect(server_address).await?;
  93. // Create the protocol
  94. let (csp_protocol, client_hello) = CspProtocol::new(context);
  95. Ok((
  96. Self {
  97. tcp_stream,
  98. protocol: csp_protocol,
  99. },
  100. client_hello,
  101. ))
  102. }
  103. /// Do the handshake with the chat server by exchanging the following messages:
  104. ///
  105. /// ```txt
  106. /// C -- client-hello -> S
  107. /// C <- server-hello -- S
  108. /// C ---- login ---- -> S
  109. /// C <-- login-ack ---- S
  110. /// ```
  111. pub(crate) async fn run_handshake_flow(&mut self, client_hello: OutgoingFrame) -> Result<()> {
  112. // Send the client hello
  113. debug!(length = client_hello.0.len(), "Sending client hello");
  114. self.send(&client_hello.0).await?;
  115. // Handshake by polling the CSP state
  116. for iteration in 1_usize.. {
  117. trace!("Handshake flow iteration #{iteration}");
  118. // Receive required bytes and add them
  119. let bytes = self.receive_required().await?;
  120. self.protocol.add_chunks(&[&bytes])?;
  121. // Handle instruction
  122. let Some(instruction) = self.protocol.poll()? else {
  123. continue;
  124. };
  125. // We do not expect an incoming payload at this stage
  126. if let Some(incoming_payload) = instruction.incoming_payload {
  127. let message = "Unexpected incoming payload during handshake";
  128. error!(?incoming_payload, message);
  129. bail!(message)
  130. }
  131. // Send any outgoing frame
  132. if let Some(frame) = instruction.outgoing_frame {
  133. self.send(&frame.0).await?;
  134. }
  135. // Check if we've completed the handshake
  136. if let Some(CspStateUpdate::PostHandshake { queued_messages }) = instruction.state_update {
  137. info!(queued_messages, "Handshake complete");
  138. break;
  139. }
  140. }
  141. Ok(())
  142. }
  143. /// Run the payload exchange flow until stopped.
  144. pub(crate) async fn run_payload_flow(&mut self, mut queues: PayloadQueuesForProtocol) -> Result<()> {
  145. let mut read_buffer = [0_u8; 8192];
  146. for iteration in 1_usize.. {
  147. trace!("Payload flow iteration #{iteration}");
  148. // Poll for any pending instruction
  149. let mut instruction = self.protocol.poll()?;
  150. if instruction.is_none() {
  151. // No pending instruction left, wait for more input
  152. instruction = tokio::select! {
  153. // Forward any incoming chunks from the TCP stream
  154. _ = self.tcp_stream.readable() => {
  155. let length = self.try_receive(&mut read_buffer)?;
  156. // Add chunks (poll in the next iteration)
  157. self.protocol
  158. .add_chunks(&[read_buffer.get(..length)
  159. .expect("Amount of read bytes should be available")])?;
  160. None
  161. }
  162. // Forward any outgoing payloads
  163. Some(outgoing_payload) = queues.outgoing.recv() => {
  164. debug!(?outgoing_payload, "Sending payload");
  165. Some(self.protocol.create_payload(&outgoing_payload)?)
  166. }
  167. }
  168. }
  169. let Some(instruction) = instruction else {
  170. continue;
  171. };
  172. // We do not expect any state updates at this stage
  173. if let Some(state_update) = instruction.state_update {
  174. let message = "Unexpected state update after handshake";
  175. error!(?state_update, message);
  176. bail!(message)
  177. }
  178. // Log any incoming payload
  179. if let Some(incoming_payload) = instruction.incoming_payload {
  180. debug!(?incoming_payload, "Received payload");
  181. queues.incoming.send(incoming_payload).await?;
  182. }
  183. // Send any outgoing frame
  184. if let Some(frame) = instruction.outgoing_frame {
  185. self.send(&frame.0).await?;
  186. }
  187. }
  188. Ok(())
  189. }
  190. /// Shut down the TCP connection
  191. pub(crate) async fn shutdown(&mut self) -> Result<()> {
  192. info!("Shutting down TCP connection");
  193. Ok(self.tcp_stream.shutdown().await?)
  194. }
  195. /// Send bytes to the server over the TCP connection
  196. async fn send(&mut self, bytes: &[u8]) -> Result<()> {
  197. trace!(length = bytes.len(), "Sending bytes");
  198. self.tcp_stream.write_all(bytes).await?;
  199. Ok(())
  200. }
  201. async fn receive_required(&mut self) -> Result<Vec<u8>> {
  202. // Get the minimum amount of bytes we'll need to receive
  203. let length = self.protocol.next_required_length()?;
  204. let mut buffer = vec![0; length];
  205. trace!(?length, "Reading bytes");
  206. // If there is nothing to read, return immediately
  207. if length == 0 {
  208. return Ok(buffer);
  209. }
  210. // Read the exact number of bytes required
  211. let _ = self.tcp_stream.read_exact(&mut buffer).await?;
  212. // Read more if available
  213. match self.tcp_stream.try_read_buf(&mut buffer) {
  214. Ok(0) => {
  215. // Remote shut down our reading end. But we still need to process the previously
  216. // read bytes.
  217. warn!("TCP reading end closed");
  218. },
  219. Ok(length) => {
  220. trace!(length, "Got additional bytes");
  221. },
  222. Err(error) if error.kind() == io::ErrorKind::WouldBlock => {
  223. trace!("No additional bytes available");
  224. },
  225. Err(error) => {
  226. return Err(error.into());
  227. },
  228. }
  229. debug!(length = buffer.len(), "Received bytes");
  230. Ok(buffer)
  231. }
  232. fn try_receive(&mut self, buffer: &mut [u8]) -> Result<usize> {
  233. match self.tcp_stream.try_read(buffer) {
  234. Ok(0) => {
  235. // Remote shut down our reading end. But we still need to process the previously
  236. // read bytes.
  237. warn!("TCP reading end closed");
  238. Ok(0)
  239. },
  240. Ok(length) => {
  241. debug!(length, "Received bytes");
  242. Ok(length)
  243. },
  244. Err(error) if error.kind() == io::ErrorKind::WouldBlock => {
  245. trace!("No bytes to receive");
  246. Ok(0)
  247. },
  248. Err(error) => Err(error.into()),
  249. }
  250. }
  251. }
  252. async fn run_app_flow(mut queues: PayloadQueuesForMain) -> Result<()> {
  253. // Create the echo timer that will trigger an outgoing payload every 10s
  254. let mut echo_timer = time::interval_at(
  255. Instant::now()
  256. .checked_add(Duration::from_secs(10))
  257. .expect("Oops, apocalypse in 10s"),
  258. Duration::from_secs(10),
  259. );
  260. // Enter application loop
  261. loop {
  262. tokio::select! {
  263. // Send echo-request when the timer fires
  264. _ = echo_timer.tick() => {
  265. let echo_request = OutgoingPayload::EchoRequest(
  266. EchoPayload("ping".as_bytes().to_owned()));
  267. info!(?echo_request, "Sending echo request");
  268. if queues.outgoing.send(echo_request).await.is_err() {
  269. info!("Stopping app");
  270. return Ok(())
  271. }
  272. }
  273. // Process incoming payload (or stop signal)
  274. incoming_payload = queues.incoming.recv() => {
  275. if let Some(incoming_payload) = incoming_payload {
  276. info!(?incoming_payload, "Received payload");
  277. } else {
  278. info!("Stopping app");
  279. return Ok(())
  280. }
  281. }
  282. };
  283. }
  284. }
  285. #[tokio::main]
  286. async fn main() -> Result<()> {
  287. // Configure logging
  288. init_stderr_logging(Level::TRACE);
  289. // Parse arguments for command
  290. let (context, server_address) = Main::parse_context_server_address();
  291. // Create payload queues
  292. let (app_queues, protocol_queues) = {
  293. let incoming_payload = mpsc::channel(4);
  294. let outgoing_payload = mpsc::channel(4);
  295. (
  296. PayloadQueuesForMain {
  297. incoming: incoming_payload.1,
  298. outgoing: outgoing_payload.0,
  299. },
  300. PayloadQueuesForProtocol {
  301. incoming: incoming_payload.0,
  302. outgoing: outgoing_payload.1,
  303. },
  304. )
  305. };
  306. // Create protocol connection
  307. let (mut csp_connection, client_hello) = CspConnection::new(server_address, context).await?;
  308. // Run the handshake flow
  309. csp_connection.run_handshake_flow(client_hello).await?;
  310. // Spawn a task that simulates a payload sender/receiver flow typical for an application
  311. let app_handle = task::spawn(run_app_flow(app_queues));
  312. // Run the payload flow
  313. tokio::select! {
  314. _ = csp_connection.run_payload_flow(protocol_queues) => {}
  315. _ = signal::ctrl_c() => {}
  316. };
  317. // Shut down
  318. app_handle.await??;
  319. csp_connection.shutdown().await?;
  320. Ok(())
  321. }
  322. #[test]
  323. fn verify_cli() {
  324. use clap::CommandFactory;
  325. Main::command().debug_assert();
  326. }