csp_e2e_receive.rs 22 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417418419420421422423424425426427428429430431432433434435436437438439440441442443444445446447448449450451452453454455456457458459460461462463464465466467468469470471472473474475476477478479480481482483484485486487488489490491492493494495496497498499500501502503504505506507508509510511512513514515516517518519520521522523524525526527528529530531532533534535536537538539540541542543544545546547548549550551552553554555556557558559560561562563564565566567568569570571572573574575576577578579580581582583584585586587588589590
  1. //! Example for usage of the Chat Server E2EE Protocol, connecting to the chat server and receiving incoming
  2. //! messages.
  3. #![expect(unused_crate_dependencies, reason = "Example triggered false positive")]
  4. #![expect(
  5. clippy::integer_division_remainder_used,
  6. reason = "Some internal of tokio::select triggers this"
  7. )]
  8. #![expect(
  9. unreachable_code,
  10. unused_variables,
  11. clippy::todo,
  12. reason = "TODO(LIB-16): Finalise this, then remove me"
  13. )]
  14. use core::cell::RefCell;
  15. use std::io;
  16. use anyhow::bail;
  17. use clap::Parser;
  18. use libthreema::{
  19. cli::{FullIdentityConfig, FullIdentityConfigOptions},
  20. common::ClientInfo,
  21. csp::{
  22. CspProtocol, CspProtocolContext, CspProtocolInstruction, CspStateUpdate,
  23. frame::OutgoingFrame,
  24. payload::{IncomingPayload, MessageAck, MessageWithMetadataBox, OutgoingPayload},
  25. },
  26. csp_e2e::{
  27. CspE2eProtocol, CspE2eProtocolContextInit,
  28. contacts::{
  29. create::{CreateContactsInstruction, CreateContactsResponse},
  30. lookup::ContactsLookupResponse,
  31. update::{UpdateContactsInstruction, UpdateContactsResponse},
  32. },
  33. incoming_message::task::{IncomingMessageInstruction, IncomingMessageLoop, IncomingMessageResponse},
  34. reflect::{ReflectInstruction, ReflectResponse},
  35. transaction::{
  36. begin::{BeginTransactionInstruction, BeginTransactionResponse},
  37. commit::{CommitTransactionInstruction, CommitTransactionResponse},
  38. },
  39. },
  40. https::cli::https_client_builder,
  41. model::provider::in_memory::{DefaultShortcutProvider, InMemoryDb, InMemoryDbInit, InMemoryDbSettings},
  42. utils::logging::init_stderr_logging,
  43. };
  44. use tokio::{
  45. io::{AsyncReadExt as _, AsyncWriteExt as _},
  46. net::TcpStream,
  47. signal,
  48. sync::mpsc,
  49. };
  50. use tracing::{Level, debug, error, info, trace, warn};
// Command line arguments of this example.
//
// NOTE: Plain `//` comments are used on purpose: `///` doc comments would be picked up by the
// clap derive and change the generated `--help` output.
#[derive(Parser)]
#[command()]
struct CspE2eReceiveCommand {
    // Identity and server configuration options, flattened into the top-level arguments
    #[command(flatten)]
    config: FullIdentityConfigOptions,
}
  57. enum PayloadForCspE2e {
  58. Message(MessageWithMetadataBox),
  59. MessageAck(MessageAck),
  60. }
  61. impl From<PayloadForCspE2e> for OutgoingPayload {
  62. fn from(payload: PayloadForCspE2e) -> Self {
  63. match payload {
  64. PayloadForCspE2e::Message(message) => OutgoingPayload::MessageWithMetadataBox(message),
  65. PayloadForCspE2e::MessageAck(message_ack) => OutgoingPayload::MessageAck(message_ack),
  66. }
  67. }
  68. }
/// Payload queues for the main process, i.e. the CSP E2EE side.
///
/// `main` hands this half to `CspE2eProtocolRunner::run_receive_flow`.
struct PayloadQueuesForCspE2e {
    /// Payloads received from the chat server, forwarded by the CSP protocol runner
    incoming: mpsc::Receiver<PayloadForCspE2e>,
    /// Payloads the CSP protocol runner should send to the chat server
    outgoing: mpsc::Sender<PayloadForCspE2e>,
}

/// Payload queues for the protocol flow runner, i.e. the CSP side.
///
/// `main` hands this half to `CspProtocolRunner::run_payload_flow`. Each queue is the opposite
/// end of the corresponding queue in `PayloadQueuesForCspE2e`.
struct PayloadQueuesForCsp {
    /// Sending end for payloads received from the chat server
    incoming: mpsc::Sender<PayloadForCspE2e>,
    /// Receiving end for payloads to be sent to the chat server
    outgoing: mpsc::Receiver<PayloadForCspE2e>,
}
  79. /// The Client Server Protocol connection handler
  80. struct CspProtocolRunner {
  81. /// The TCP stream
  82. stream: TcpStream,
  83. /// An instance of the [`CspProtocol`] state machine
  84. protocol: CspProtocol,
  85. }
  86. impl CspProtocolRunner {
  87. /// Initiate a CSP protocol connection and hand out the initial `client_hello` message
  88. #[tracing::instrument(skip_all)]
  89. async fn new(
  90. server_address: Vec<(String, u16)>,
  91. context: CspProtocolContext,
  92. ) -> anyhow::Result<(Self, OutgoingFrame)> {
  93. // Connect via TCP
  94. debug!(?server_address, "Establishing TCP connection to chat server",);
  95. let tcp_stream = TcpStream::connect(
  96. server_address
  97. .first()
  98. .expect("CSP config should have at least one address"),
  99. )
  100. .await?;
  101. // Create the protocol
  102. let (csp_protocol, client_hello) = CspProtocol::new(context);
  103. Ok((
  104. Self {
  105. stream: tcp_stream,
  106. protocol: csp_protocol,
  107. },
  108. client_hello,
  109. ))
  110. }
  111. /// Do the handshake with the chat server by exchanging the following messages:
  112. ///
  113. /// ```txt
  114. /// C -- client-hello -> S
  115. /// C <- server-hello -- S
  116. /// C ---- login ---- -> S
  117. /// C <-- login-ack ---- S
  118. /// ```
  119. #[tracing::instrument(skip_all)]
  120. async fn run_handshake_flow(&mut self, client_hello: OutgoingFrame) -> anyhow::Result<()> {
  121. // Send the client hello
  122. debug!(length = client_hello.0.len(), "Sending client hello");
  123. self.send(&client_hello.0).await?;
  124. // Handshake by polling the CSP state
  125. for iteration in 1_usize.. {
  126. trace!("Iteration #{iteration}");
  127. // Receive required bytes and add them
  128. let bytes = self.receive_required().await?;
  129. self.protocol.add_chunks(&[&bytes])?;
  130. // Handle instruction
  131. let Some(instruction) = self.protocol.poll()? else {
  132. continue;
  133. };
  134. // We do not expect an incoming payload at this stage
  135. if let Some(incoming_payload) = instruction.incoming_payload {
  136. let message = "Unexpected incoming payload during handshake";
  137. error!(?incoming_payload, message);
  138. bail!(message)
  139. }
  140. // Send any outgoing frame
  141. if let Some(frame) = instruction.outgoing_frame {
  142. self.send(&frame.0).await?;
  143. }
  144. // Check if we've completed the handshake
  145. if let Some(CspStateUpdate::PostHandshake { queued_messages }) = instruction.state_update {
  146. info!(queued_messages, "Handshake complete");
  147. break;
  148. }
  149. }
  150. Ok(())
  151. }
  152. /// Run the payload exchange flow until stopped.
  153. #[tracing::instrument(skip_all)]
  154. async fn run_payload_flow(&mut self, mut queues: PayloadQueuesForCsp) -> anyhow::Result<()> {
  155. let mut read_buffer = [0_u8; 8192];
  156. let mut next_instruction: Option<CspProtocolInstruction> = None;
  157. for iteration in 1_usize.. {
  158. trace!("Iteration #{iteration}");
  159. // Poll for an instruction, if necessary
  160. if next_instruction.is_none() {
  161. next_instruction = self.protocol.poll()?;
  162. }
  163. // Wait for more input, if necessary
  164. if next_instruction.is_none() {
  165. next_instruction = tokio::select! {
  166. // Forward any incoming chunks from the TCP stream
  167. _ = self.stream.readable() => {
  168. let length = self.try_receive(&mut read_buffer)?;
  169. // Add chunks (poll in the next iteration)
  170. self.protocol
  171. .add_chunks(&[read_buffer.get(..length)
  172. .expect("Amount of read bytes should be available")])?;
  173. None
  174. }
  175. // Forward any outgoing payloads
  176. outgoing_payload = queues.outgoing.recv() => {
  177. if let Some(outgoing_payload) = outgoing_payload {
  178. let outgoing_payload = OutgoingPayload::from(outgoing_payload);
  179. debug!(?outgoing_payload, "Sending payload");
  180. Some(self.protocol.create_payload(&outgoing_payload)?)
  181. } else {
  182. break
  183. }
  184. }
  185. };
  186. }
  187. // Handle instruction
  188. let Some(current_instruction) = next_instruction.take() else {
  189. continue;
  190. };
  191. // We do not expect any state updates at this stage
  192. if let Some(state_update) = current_instruction.state_update {
  193. let message = "Unexpected state update after handshake";
  194. error!(?state_update, message);
  195. bail!(message)
  196. }
  197. // Handle any incoming payload
  198. if let Some(incoming_payload) = current_instruction.incoming_payload {
  199. debug!(?incoming_payload, "Received payload");
  200. match incoming_payload {
  201. IncomingPayload::EchoRequest(echo_payload) => {
  202. // Respond to echo request
  203. next_instruction = Some(
  204. self.protocol
  205. .create_payload(&OutgoingPayload::EchoResponse(echo_payload))?,
  206. );
  207. },
  208. IncomingPayload::MessageWithMetadataBox(payload) => {
  209. // Forward message
  210. queues.incoming.send(PayloadForCspE2e::Message(payload)).await?;
  211. },
  212. IncomingPayload::MessageAck(payload) => {
  213. // Forward message ack
  214. queues
  215. .incoming
  216. .send(PayloadForCspE2e::MessageAck(payload))
  217. .await?;
  218. },
  219. IncomingPayload::EchoResponse(_)
  220. | IncomingPayload::QueueSendComplete
  221. | IncomingPayload::DeviceCookieChangeIndication
  222. | IncomingPayload::CloseError(_)
  223. | IncomingPayload::ServerAlert(_)
  224. | IncomingPayload::UnknownPayload { .. } => {},
  225. }
  226. }
  227. // Send any outgoing frame
  228. if let Some(frame) = current_instruction.outgoing_frame {
  229. self.send(&frame.0).await?;
  230. }
  231. }
  232. Ok(())
  233. }
  234. /// Shut down the TCP connection
  235. #[tracing::instrument(skip_all)]
  236. async fn shutdown(&mut self) -> anyhow::Result<()> {
  237. info!("Shutting down TCP connection");
  238. Ok(self.stream.shutdown().await?)
  239. }
  240. /// Send bytes to the server over the TCP connection
  241. #[tracing::instrument(skip_all, fields(bytes_length = bytes.len()))]
  242. async fn send(&mut self, bytes: &[u8]) -> anyhow::Result<()> {
  243. trace!(length = bytes.len(), "Sending bytes");
  244. self.stream.write_all(bytes).await?;
  245. Ok(())
  246. }
  247. #[tracing::instrument(skip_all)]
  248. async fn receive_required(&mut self) -> anyhow::Result<Vec<u8>> {
  249. // Get the minimum amount of bytes we'll need to receive
  250. let length = self.protocol.next_required_length()?;
  251. let mut buffer = vec![0; length];
  252. trace!(?length, "Reading bytes");
  253. // If there is nothing to read, return immediately
  254. if length == 0 {
  255. return Ok(buffer);
  256. }
  257. // Read the exact number of bytes required
  258. let _ = self.stream.read_exact(&mut buffer).await?;
  259. // Read more if available
  260. match self.stream.try_read_buf(&mut buffer) {
  261. Ok(0) => {
  262. // Remote shut down our reading end gracefully.
  263. //
  264. // IMPORTANT: An implementation needs to ensure that it stops gracefully by processing any
  265. // remaining payloads prior to stopping the protocol. This example implementation ensures this
  266. // by handling all pending instructions prior to polling for more data. The only case we bail
  267. // is therefore when our instruction queue is already dry.
  268. bail!("TCP reading end closed")
  269. },
  270. Ok(length) => {
  271. trace!(length, "Got additional bytes");
  272. },
  273. Err(error) if error.kind() == io::ErrorKind::WouldBlock => {
  274. trace!("No additional bytes available");
  275. },
  276. Err(error) => {
  277. return Err(error.into());
  278. },
  279. }
  280. debug!(length = buffer.len(), "Received bytes");
  281. Ok(buffer)
  282. }
  283. #[tracing::instrument(skip_all)]
  284. fn try_receive(&mut self, buffer: &mut [u8]) -> anyhow::Result<usize> {
  285. match self.stream.try_read(buffer) {
  286. Ok(0) => {
  287. // Remote shut down our reading end. But we still need to process the previously
  288. // read bytes.
  289. warn!("TCP reading end closed");
  290. Ok(0)
  291. },
  292. Ok(length) => {
  293. debug!(length, "Received bytes");
  294. Ok(length)
  295. },
  296. Err(error) if error.kind() == io::ErrorKind::WouldBlock => {
  297. trace!("No bytes to receive");
  298. Ok(0)
  299. },
  300. Err(error) => Err(error.into()),
  301. }
  302. }
  303. }
  304. struct CspE2eProtocolRunner {
  305. /// An instance of the [`CspE2eProtocol`] state machine
  306. protocol: CspE2eProtocol,
  307. /// HTTP client
  308. http_client: reqwest::Client,
  309. }
  310. impl CspE2eProtocolRunner {
  311. #[tracing::instrument(skip_all)]
  312. fn new(http_client: reqwest::Client, context: CspE2eProtocolContextInit) -> anyhow::Result<Self> {
  313. Ok(Self {
  314. protocol: CspE2eProtocol::new(context),
  315. http_client,
  316. })
  317. }
  318. /// Run the receive flow until stopped.
  319. #[tracing::instrument(skip_all)]
  320. async fn run_receive_flow(&mut self, mut queues: PayloadQueuesForCspE2e) -> anyhow::Result<()> {
  321. for iteration in 1_usize.. {
  322. trace!("Receive flow iteration #{iteration}");
  323. // Handle any incoming payloads until we have a task
  324. let mut task = match queues.incoming.recv().await {
  325. Some(PayloadForCspE2e::Message(message)) => self.protocol.handle_incoming_message(message),
  326. Some(PayloadForCspE2e::MessageAck(message_ack)) => {
  327. warn!(?message_ack, "Unexpected message-ack");
  328. continue;
  329. },
  330. None => break,
  331. };
  332. // Handle task
  333. match task.poll(self.protocol.context())? {
  334. IncomingMessageLoop::Instruction(IncomingMessageInstruction::FetchSender(instruction)) => {
  335. // Run both requests simultaneously
  336. let work_directory_request_future = async {
  337. match instruction.work_directory_request {
  338. Some(work_directory_request) => {
  339. work_directory_request.send(&self.http_client).await.map(Some)
  340. },
  341. None => Ok(None),
  342. }
  343. };
  344. let (directory_result, work_directory_result) = tokio::join!(
  345. instruction.directory_request.send(&self.http_client),
  346. work_directory_request_future,
  347. );
  348. // Forward response
  349. task.response(IncomingMessageResponse::FetchSender(ContactsLookupResponse {
  350. directory_result,
  351. work_directory_result: work_directory_result.transpose(),
  352. }))?;
  353. },
  354. IncomingMessageLoop::Instruction(IncomingMessageInstruction::CreateContact(instruction)) => {
  355. match instruction {
  356. CreateContactsInstruction::BeginTransaction(instruction) => {
  357. // Begin transaction and forward response, if any
  358. let response = self.begin_transaction(instruction).await?;
  359. if let Some(response) = response {
  360. task.response(IncomingMessageResponse::CreateContact(
  361. CreateContactsResponse::BeginTransactionResponse(response),
  362. ))?;
  363. }
  364. },
  365. CreateContactsInstruction::ReflectAndCommitTransaction(instruction) => {
  366. // Reflect and commit transaction and forward response
  367. task.response(IncomingMessageResponse::CreateContact(
  368. CreateContactsResponse::CommitTransactionResponse(
  369. self.reflect_and_commit_transaction(instruction).await?,
  370. ),
  371. ))?;
  372. },
  373. }
  374. },
  375. IncomingMessageLoop::Instruction(IncomingMessageInstruction::UpdateContact(instruction)) => {
  376. match instruction {
  377. UpdateContactsInstruction::BeginTransaction(instruction) => {
  378. // Begin transaction and forward response, if any
  379. let response = self.begin_transaction(instruction).await?;
  380. if let Some(response) = response {
  381. task.response(IncomingMessageResponse::UpdateContact(
  382. UpdateContactsResponse::BeginTransactionResponse(response),
  383. ))?;
  384. }
  385. },
  386. UpdateContactsInstruction::ReflectAndCommitTransaction(instruction) => {
  387. // Reflect and commit transaction and forward response
  388. task.response(IncomingMessageResponse::UpdateContact(
  389. UpdateContactsResponse::CommitTransactionResponse(
  390. self.reflect_and_commit_transaction(instruction).await?,
  391. ),
  392. ))?;
  393. },
  394. }
  395. },
  396. IncomingMessageLoop::Instruction(IncomingMessageInstruction::ReflectMessage(instruction)) => {
  397. task.response(IncomingMessageResponse::ReflectMessage(
  398. self.reflect(instruction).await?,
  399. ))?;
  400. },
  401. IncomingMessageLoop::Done(result) => {
  402. // Send message acknowledgement, if any
  403. if let Some(outgoing_message_ack) = result.outgoing_message_ack {
  404. queues
  405. .outgoing
  406. .send(PayloadForCspE2e::MessageAck(outgoing_message_ack))
  407. .await?;
  408. }
  409. // TODO(LIB-16). Enqueue outgoing message task, if any
  410. },
  411. }
  412. }
  413. Ok(())
  414. }
  415. #[tracing::instrument(skip_all)]
  416. async fn begin_transaction(
  417. &self,
  418. instruction: BeginTransactionInstruction,
  419. ) -> anyhow::Result<Option<BeginTransactionResponse>> {
  420. match instruction {
  421. BeginTransactionInstruction::TransactionRejected => {
  422. // TODO(LIB-16). Await TransactionEnded
  423. Ok(None)
  424. },
  425. BeginTransactionInstruction::BeginTransaction { message } => {
  426. // TODO(LIB-16). Send `BeginTransaction, await BeginTransactionAck or TransactionRejected,
  427. // then return BeginTransactionResponse(message)
  428. Ok(Some(BeginTransactionResponse::BeginTransactionReply(todo!())))
  429. },
  430. BeginTransactionInstruction::AbortTransaction { message } => {
  431. // TODO(LIB-16). Send `CommitTransaction`, await CommitTransactionAck, then return
  432. // AbortTransaction(CommitTransactionAck)
  433. Ok(Some(BeginTransactionResponse::AbortTransactionResponse(todo!())))
  434. },
  435. }
  436. }
  437. #[tracing::instrument(skip_all)]
  438. async fn reflect_and_commit_transaction(
  439. &self,
  440. instruction: CommitTransactionInstruction,
  441. ) -> anyhow::Result<CommitTransactionResponse> {
  442. // TODO(LIB-16). Reflect messages, then immediately commit. Await CommitAck and gather any
  443. // reflect-acks
  444. Ok(CommitTransactionResponse {
  445. acknowledged_reflect_ids: todo!(),
  446. commit_transaction_ack: todo!(),
  447. })
  448. }
  449. #[tracing::instrument(skip_all)]
  450. async fn reflect(&self, instruction: ReflectInstruction) -> anyhow::Result<ReflectResponse> {
  451. // TODO(LIB-16). Reflect messages, then wait for corresponding reflect-acks
  452. Ok(ReflectResponse {
  453. acknowledged_reflect_ids: todo!(),
  454. })
  455. }
  456. }
  457. #[tokio::main]
  458. async fn main() -> anyhow::Result<()> {
  459. // Configure logging
  460. init_stderr_logging(Level::TRACE);
  461. // Create HTTP client
  462. let http_client = https_client_builder().build()?;
  463. // Parse arguments for command
  464. let arguments = CspE2eReceiveCommand::parse();
  465. let config = FullIdentityConfig::from_options(&http_client, arguments.config).await?;
  466. // Create CSP E2EE protocol context
  467. let mut database = InMemoryDb::from(InMemoryDbInit {
  468. user_identity: config.minimal.user_identity,
  469. settings: InMemoryDbSettings {
  470. block_unknown_identities: false,
  471. },
  472. contacts: vec![],
  473. blocked_identities: vec![],
  474. });
  475. let csp_e2e_context = CspE2eProtocolContextInit {
  476. client_info: ClientInfo::Libthreema,
  477. config: config.minimal.common.config.clone(),
  478. csp_e2e: config.csp_e2e_context_init(Box::new(RefCell::new(database.csp_e2e_nonce_provider()))),
  479. d2x: config.d2x_context_init(Box::new(RefCell::new(database.d2d_nonce_provider()))),
  480. shortcut: Box::new(DefaultShortcutProvider),
  481. settings: Box::new(RefCell::new(database.settings_provider())),
  482. contacts: Box::new(RefCell::new(database.contact_provider())),
  483. messages: Box::new(RefCell::new(database.message_provider())),
  484. };
  485. // Create payload queues
  486. let (csp_e2e_queues, csp_queues) = {
  487. let incoming_payload = mpsc::channel(4);
  488. let outgoing_payload = mpsc::channel(4);
  489. (
  490. PayloadQueuesForCspE2e {
  491. incoming: incoming_payload.1,
  492. outgoing: outgoing_payload.0,
  493. },
  494. PayloadQueuesForCsp {
  495. incoming: incoming_payload.0,
  496. outgoing: outgoing_payload.1,
  497. },
  498. )
  499. };
  500. // Create CSP protocol and establish a connection
  501. let (mut csp_runner, client_hello) = CspProtocolRunner::new(
  502. config
  503. .minimal
  504. .common
  505. .config
  506. .chat_server_address
  507. .addresses(config.csp_server_group),
  508. config.csp_context().expect("Configuration should be valid"),
  509. )
  510. .await?;
  511. // Run the handshake flow
  512. csp_runner.run_handshake_flow(client_hello).await?;
  513. // Create CSP E2E protocol
  514. let mut csp_e2e_protocol = CspE2eProtocolRunner::new(http_client, csp_e2e_context)?;
  515. // Run the protocols
  516. tokio::select! {
  517. _ = csp_runner.run_payload_flow(csp_queues) => {},
  518. _ = csp_e2e_protocol.run_receive_flow(csp_e2e_queues) => {},
  519. _ = signal::ctrl_c() => {},
  520. };
  521. // Shut down
  522. csp_runner.shutdown().await?;
  523. Ok(())
  524. }
  525. #[test]
  526. fn verify_cli() {
  527. use clap::CommandFactory;
  528. CspE2eReceiveCommand::command().debug_assert();
  529. }