csp_e2e_receive.rs 23 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417418419420421422423424425426427428429430431432433434435436437438439440441442443444445446447448449450451452453454455456457458459460461462463464465466467468469470471472473474475476477478479480481482483484485486487488489490491492493494495496497498499500501502503504505506507508509510511512513514515516517518519520521522523524525526527528529530531532533534535536537538539540541542543544545546547548549550551552553554555556557558559560561562563564565566567568569570571572573574575576577578579580581582583584585586587588589590591592593594595596597598599600601602603604605606607608
  1. //! Example for usage of the Chat Server E2EE Protocol, connecting to the chat server and receiving incoming
  2. //! messages.
  3. #![expect(unused_crate_dependencies, reason = "Example triggered false positive")]
  4. #![expect(
  5. clippy::integer_division_remainder_used,
  6. reason = "Some internal of tokio::select triggers this"
  7. )]
  8. #![expect(
  9. unreachable_code,
  10. unused_variables,
  11. clippy::todo,
  12. reason = "TODO(LIB-16): Finalise this, then remove me"
  13. )]
  14. use core::cell::RefCell;
  15. use std::{io, rc::Rc};
  16. use anyhow::bail;
  17. use clap::Parser;
  18. use data_encoding::HEXLOWER;
  19. use libthreema::{
  20. cli::{FullIdentityConfig, FullIdentityConfigOptions},
  21. common::ClientInfo,
  22. csp::{
  23. CspProtocol, CspProtocolContext, CspProtocolInstruction, CspStateUpdate,
  24. payload::{IncomingPayload, MessageAck, MessageWithMetadataBox, OutgoingFrame, OutgoingPayload},
  25. },
  26. csp_e2e::{
  27. CspE2eProtocol, CspE2eProtocolContextInit,
  28. contacts::{
  29. create::{CreateContactsInstruction, CreateContactsResponse},
  30. lookup::ContactsLookupResponse,
  31. update::{UpdateContactsInstruction, UpdateContactsResponse},
  32. },
  33. message::task::incoming::{
  34. IncomingMessageInstruction, IncomingMessageLoop, IncomingMessageResponse, IncomingMessageTask,
  35. },
  36. reflect::{ReflectInstruction, ReflectResponse},
  37. transaction::{
  38. begin::{BeginTransactionInstruction, BeginTransactionResponse},
  39. commit::{CommitTransactionInstruction, CommitTransactionResponse},
  40. },
  41. },
  42. https::cli::https_client_builder,
  43. model::provider::in_memory::{DefaultShortcutProvider, InMemoryDb, InMemoryDbInit, InMemoryDbSettings},
  44. utils::logging::init_stderr_logging,
  45. };
  46. use tokio::{
  47. io::{AsyncReadExt as _, AsyncWriteExt as _},
  48. net::TcpStream,
  49. signal,
  50. sync::mpsc,
  51. };
  52. use tracing::{Level, debug, error, info, trace, warn};
  53. #[derive(Parser)]
  54. #[command()]
  55. struct CspE2eReceiveCommand {
  56. #[command(flatten)]
  57. config: FullIdentityConfigOptions,
  58. }
  59. enum IncomingPayloadForCspE2e {
  60. Message(MessageWithMetadataBox),
  61. MessageAck(MessageAck),
  62. }
  63. enum OutgoingPayloadForCspE2e {
  64. MessageAck(MessageAck),
  65. }
  66. impl From<OutgoingPayloadForCspE2e> for OutgoingPayload {
  67. fn from(payload: OutgoingPayloadForCspE2e) -> Self {
  68. match payload {
  69. OutgoingPayloadForCspE2e::MessageAck(message_ack) => OutgoingPayload::MessageAck(message_ack),
  70. }
  71. }
  72. }
  73. /// Payload queues for the main process
  74. struct PayloadQueuesForCspE2e {
  75. incoming: mpsc::Receiver<IncomingPayloadForCspE2e>,
  76. outgoing: mpsc::Sender<OutgoingPayloadForCspE2e>,
  77. }
  78. /// Payload queues for the protocol flow runner
  79. struct PayloadQueuesForCsp {
  80. incoming: mpsc::Sender<IncomingPayloadForCspE2e>,
  81. outgoing: mpsc::Receiver<OutgoingPayloadForCspE2e>,
  82. }
  83. struct CspProtocolRunner {
  84. /// The TCP stream
  85. stream: TcpStream,
  86. /// An instance of the [`CspProtocol`] state machine
  87. protocol: CspProtocol,
  88. }
  89. impl CspProtocolRunner {
  90. /// Initiate a CSP protocol connection and hand out the initial `client_hello` message
  91. #[tracing::instrument(skip_all)]
  92. async fn new(
  93. server_address: Vec<(String, u16)>,
  94. context: CspProtocolContext,
  95. ) -> anyhow::Result<(Self, OutgoingFrame)> {
  96. // Connect via TCP
  97. debug!(?server_address, "Establishing TCP connection to chat server",);
  98. let tcp_stream = TcpStream::connect(
  99. server_address
  100. .first()
  101. .expect("CSP config should have at least one address"),
  102. )
  103. .await?;
  104. // Create the protocol
  105. let (csp_protocol, client_hello) = CspProtocol::new(context);
  106. Ok((
  107. Self {
  108. stream: tcp_stream,
  109. protocol: csp_protocol,
  110. },
  111. client_hello,
  112. ))
  113. }
  114. /// Do the handshake with the chat server by exchanging the following messages:
  115. ///
  116. /// ```txt
  117. /// C -- client-hello -> S
  118. /// C <- server-hello -- S
  119. /// C ---- login ---- -> S
  120. /// C <-- login-ack ---- S
  121. /// ```
  122. #[tracing::instrument(skip_all)]
  123. async fn run_handshake_flow(&mut self, client_hello: OutgoingFrame) -> anyhow::Result<()> {
  124. // Send the client hello
  125. debug!(length = client_hello.0.len(), "Sending client hello");
  126. self.send(&client_hello.0).await?;
  127. // Handshake by polling the CSP state
  128. for iteration in 1_usize.. {
  129. trace!("Iteration #{iteration}");
  130. // Receive required bytes and add them
  131. let bytes = self.receive_required().await?;
  132. self.protocol.add_chunks(&[&bytes])?;
  133. // Handle instruction
  134. let Some(instruction) = self.protocol.poll()? else {
  135. continue;
  136. };
  137. // We do not expect an incoming payload at this stage
  138. if let Some(incoming_payload) = instruction.incoming_payload {
  139. let message = "Unexpected incoming payload during handshake";
  140. error!(?incoming_payload, message);
  141. bail!(message)
  142. }
  143. // Send any outgoing frame
  144. if let Some(frame) = instruction.outgoing_frame {
  145. self.send(&frame.0).await?;
  146. }
  147. // Check if we've completed the handshake
  148. if let Some(CspStateUpdate::PostHandshake(login_ack_data)) = instruction.state_update {
  149. info!(?login_ack_data, "Handshake complete");
  150. break;
  151. }
  152. }
  153. Ok(())
  154. }
  155. /// Run the payload exchange flow until stopped.
  156. #[tracing::instrument(skip_all)]
  157. async fn run_payload_flow(&mut self, mut queues: PayloadQueuesForCsp) -> anyhow::Result<()> {
  158. let mut read_buffer = [0_u8; 8192];
  159. let mut next_instruction: Option<CspProtocolInstruction> = None;
  160. for iteration in 1_usize.. {
  161. trace!("Iteration #{iteration}");
  162. // Poll for an instruction, if necessary
  163. if next_instruction.is_none() {
  164. next_instruction = self.protocol.poll()?;
  165. }
  166. // Wait for more input, if necessary
  167. if next_instruction.is_none() {
  168. next_instruction = tokio::select! {
  169. // Forward any incoming chunks from the TCP stream
  170. _ = self.stream.readable() => {
  171. let length = self.try_receive(&mut read_buffer)?;
  172. // Add chunks (poll in the next iteration)
  173. self.protocol
  174. .add_chunks(&[read_buffer.get(..length)
  175. .expect("Amount of read bytes should be available")])?;
  176. None
  177. },
  178. // Forward any outgoing payloads
  179. outgoing_payload = queues.outgoing.recv() => {
  180. if let Some(outgoing_payload) = outgoing_payload {
  181. let outgoing_payload = OutgoingPayload::from(outgoing_payload);
  182. debug!(?outgoing_payload, "Sending payload");
  183. Some(self.protocol.create_payload(&outgoing_payload)?)
  184. } else {
  185. break
  186. }
  187. }
  188. };
  189. }
  190. // Handle instruction
  191. let Some(current_instruction) = next_instruction.take() else {
  192. continue;
  193. };
  194. // We do not expect any state updates at this stage
  195. if let Some(state_update) = current_instruction.state_update {
  196. let message = "Unexpected state update after handshake";
  197. error!(?state_update, message);
  198. bail!(message)
  199. }
  200. // Handle any incoming payload
  201. if let Some(incoming_payload) = current_instruction.incoming_payload {
  202. debug!(?incoming_payload, "Received payload");
  203. match incoming_payload {
  204. IncomingPayload::EchoRequest(echo_payload) => {
  205. // Respond to echo request
  206. next_instruction = Some(
  207. self.protocol
  208. .create_payload(&OutgoingPayload::EchoResponse(echo_payload))?,
  209. );
  210. },
  211. IncomingPayload::MessageWithMetadataBox(payload) => {
  212. // Forward message
  213. queues
  214. .incoming
  215. .send(IncomingPayloadForCspE2e::Message(payload))
  216. .await?;
  217. },
  218. IncomingPayload::MessageAck(payload) => {
  219. // Forward message ack
  220. queues
  221. .incoming
  222. .send(IncomingPayloadForCspE2e::MessageAck(payload))
  223. .await?;
  224. },
  225. IncomingPayload::EchoResponse(_)
  226. | IncomingPayload::QueueSendComplete
  227. | IncomingPayload::DeviceCookieChangeIndication
  228. | IncomingPayload::CloseError(_)
  229. | IncomingPayload::ServerAlert(_)
  230. | IncomingPayload::UnknownPayload { .. } => {},
  231. }
  232. }
  233. // Send any outgoing frame
  234. if let Some(frame) = current_instruction.outgoing_frame {
  235. self.send(&frame.0).await?;
  236. }
  237. }
  238. Ok(())
  239. }
  240. /// Shut down the TCP connection
  241. #[tracing::instrument(skip_all)]
  242. async fn shutdown(&mut self) -> anyhow::Result<()> {
  243. info!("Shutting down TCP connection");
  244. Ok(self.stream.shutdown().await?)
  245. }
  246. /// Send bytes to the server over the TCP connection
  247. #[tracing::instrument(skip_all, fields(bytes_length = bytes.len()))]
  248. async fn send(&mut self, bytes: &[u8]) -> anyhow::Result<()> {
  249. trace!(length = bytes.len(), "Sending bytes");
  250. self.stream.write_all(bytes).await?;
  251. Ok(())
  252. }
  253. #[tracing::instrument(skip_all)]
  254. async fn receive_required(&mut self) -> anyhow::Result<Vec<u8>> {
  255. // Get the minimum amount of bytes we'll need to receive
  256. let length = self.protocol.next_required_length()?;
  257. let mut buffer = vec![0; length];
  258. trace!(?length, "Reading bytes");
  259. // If there is nothing to read, return immediately
  260. if length == 0 {
  261. return Ok(buffer);
  262. }
  263. // Read the exact number of bytes required
  264. let _ = self.stream.read_exact(&mut buffer).await?;
  265. // Read more if available
  266. match self.stream.try_read_buf(&mut buffer) {
  267. Ok(0) => {
  268. // Remote shut down our reading end gracefully.
  269. //
  270. // IMPORTANT: An implementation needs to ensure that it stops gracefully by processing any
  271. // remaining payloads prior to stopping the protocol. This example implementation ensures this
  272. // by handling all pending instructions prior to polling for more data. The only case we bail
  273. // is therefore when our instruction queue is already dry.
  274. bail!("TCP reading end closed")
  275. },
  276. Ok(length) => {
  277. trace!(length, "Got additional bytes");
  278. },
  279. Err(error) if error.kind() == io::ErrorKind::WouldBlock => {
  280. trace!("No additional bytes available");
  281. },
  282. Err(error) => {
  283. return Err(error.into());
  284. },
  285. }
  286. debug!(length = buffer.len(), "Received bytes");
  287. Ok(buffer)
  288. }
  289. #[tracing::instrument(skip_all)]
  290. fn try_receive(&mut self, buffer: &mut [u8]) -> anyhow::Result<usize> {
  291. match self.stream.try_read(buffer) {
  292. Ok(0) => {
  293. // Remote shut down our reading end. But we still need to process the previously
  294. // read bytes.
  295. warn!("TCP reading end closed");
  296. Ok(0)
  297. },
  298. Ok(length) => {
  299. debug!(length, "Received bytes");
  300. Ok(length)
  301. },
  302. Err(error) if error.kind() == io::ErrorKind::WouldBlock => {
  303. trace!("No bytes to receive");
  304. Ok(0)
  305. },
  306. Err(error) => Err(error.into()),
  307. }
  308. }
  309. }
  310. struct CspE2eProtocolRunner {
  311. /// An instance of the [`CspE2eProtocol`] state machine
  312. protocol: CspE2eProtocol,
  313. /// HTTP client
  314. http_client: reqwest::Client,
  315. }
  316. impl CspE2eProtocolRunner {
  317. #[tracing::instrument(skip_all)]
  318. fn new(http_client: reqwest::Client, context: CspE2eProtocolContextInit) -> anyhow::Result<Self> {
  319. Ok(Self {
  320. protocol: CspE2eProtocol::new(context),
  321. http_client,
  322. })
  323. }
  324. /// Run the receive flow until stopped.
  325. #[tracing::instrument(skip_all)]
  326. async fn run_receive_flow(&mut self, mut queues: PayloadQueuesForCspE2e) -> anyhow::Result<()> {
  327. let mut pending_task: Option<IncomingMessageTask> = None;
  328. for iteration in 1_usize.. {
  329. trace!("Receive flow iteration #{iteration}");
  330. // Handle any incoming payloads until we have a task
  331. let Some(task) = &mut pending_task else {
  332. match queues.incoming.recv().await {
  333. Some(IncomingPayloadForCspE2e::Message(message)) => {
  334. trace!(message = HEXLOWER.encode(&message.bytes), "Raw incoming message");
  335. info!(?message, "Incoming message");
  336. pending_task = Some(self.protocol.handle_incoming_message(message));
  337. },
  338. Some(IncomingPayloadForCspE2e::MessageAck(message_ack)) => {
  339. warn!(?message_ack, "Unexpected message-ack");
  340. },
  341. None => {},
  342. }
  343. continue;
  344. };
  345. // Handle task
  346. match task.poll(self.protocol.context())? {
  347. IncomingMessageLoop::Instruction(IncomingMessageInstruction::FetchSender(instruction)) => {
  348. // Run both requests simultaneously
  349. let work_directory_request_future = async {
  350. match instruction.work_directory_request {
  351. Some(work_directory_request) => {
  352. work_directory_request.send(&self.http_client).await.map(Some)
  353. },
  354. None => Ok(None),
  355. }
  356. };
  357. let (directory_result, work_directory_result) = tokio::join!(
  358. instruction.directory_request.send(&self.http_client),
  359. work_directory_request_future,
  360. );
  361. // Forward response
  362. task.response(IncomingMessageResponse::FetchSender(ContactsLookupResponse {
  363. directory_result,
  364. work_directory_result: work_directory_result.transpose(),
  365. }))?;
  366. },
  367. IncomingMessageLoop::Instruction(IncomingMessageInstruction::CreateContact(instruction)) => {
  368. match instruction {
  369. CreateContactsInstruction::BeginTransaction(instruction) => {
  370. // Begin transaction and forward response, if any
  371. let response = self.begin_transaction(instruction).await?;
  372. if let Some(response) = response {
  373. task.response(IncomingMessageResponse::CreateContact(
  374. CreateContactsResponse::BeginTransactionResponse(response),
  375. ))?;
  376. }
  377. },
  378. CreateContactsInstruction::ReflectAndCommitTransaction(instruction) => {
  379. // Reflect and commit transaction and forward response
  380. task.response(IncomingMessageResponse::CreateContact(
  381. CreateContactsResponse::CommitTransactionResponse(
  382. self.reflect_and_commit_transaction(instruction).await?,
  383. ),
  384. ))?;
  385. },
  386. }
  387. },
  388. IncomingMessageLoop::Instruction(IncomingMessageInstruction::UpdateContact(instruction)) => {
  389. match instruction {
  390. UpdateContactsInstruction::BeginTransaction(instruction) => {
  391. // Begin transaction and forward response, if any
  392. let response = self.begin_transaction(instruction).await?;
  393. if let Some(response) = response {
  394. task.response(IncomingMessageResponse::UpdateContact(
  395. UpdateContactsResponse::BeginTransactionResponse(response),
  396. ))?;
  397. }
  398. },
  399. UpdateContactsInstruction::ReflectAndCommitTransaction(instruction) => {
  400. // Reflect and commit transaction and forward response
  401. task.response(IncomingMessageResponse::UpdateContact(
  402. UpdateContactsResponse::CommitTransactionResponse(
  403. self.reflect_and_commit_transaction(instruction).await?,
  404. ),
  405. ))?;
  406. },
  407. }
  408. },
  409. IncomingMessageLoop::Instruction(IncomingMessageInstruction::ReflectMessage(instruction)) => {
  410. task.response(IncomingMessageResponse::ReflectMessage(
  411. self.reflect(instruction).await?,
  412. ))?;
  413. },
  414. IncomingMessageLoop::Done(result) => {
  415. // Send message acknowledgement, if any
  416. if let Some(outgoing_message_ack) = result.outgoing_message_ack {
  417. queues
  418. .outgoing
  419. .send(OutgoingPayloadForCspE2e::MessageAck(outgoing_message_ack))
  420. .await?;
  421. }
  422. pending_task = None;
  423. // TODO(LIB-16). Enqueue outgoing message task, if any
  424. },
  425. }
  426. }
  427. Ok(())
  428. }
  429. #[tracing::instrument(skip_all)]
  430. async fn begin_transaction(
  431. &self,
  432. instruction: BeginTransactionInstruction,
  433. ) -> anyhow::Result<Option<BeginTransactionResponse>> {
  434. match instruction {
  435. BeginTransactionInstruction::TransactionRejected => {
  436. // TODO(LIB-16). Await TransactionEnded
  437. Ok(None)
  438. },
  439. BeginTransactionInstruction::BeginTransaction { message } => {
  440. // TODO(LIB-16). Send `BeginTransaction, await BeginTransactionAck or TransactionRejected,
  441. // then return BeginTransactionResponse(message)
  442. Ok(Some(BeginTransactionResponse::BeginTransactionReply(todo!())))
  443. },
  444. BeginTransactionInstruction::AbortTransaction { message } => {
  445. // TODO(LIB-16). Send `CommitTransaction`, await CommitTransactionAck, then return
  446. // AbortTransaction(CommitTransactionAck)
  447. Ok(Some(BeginTransactionResponse::AbortTransactionResponse(todo!())))
  448. },
  449. }
  450. }
  451. #[tracing::instrument(skip_all)]
  452. async fn reflect_and_commit_transaction(
  453. &self,
  454. instruction: CommitTransactionInstruction,
  455. ) -> anyhow::Result<CommitTransactionResponse> {
  456. // TODO(LIB-16). Reflect messages, then immediately commit. Await CommitAck and gather any
  457. // reflect-acks
  458. Ok(CommitTransactionResponse {
  459. acknowledged_reflect_ids: todo!(),
  460. commit_transaction_ack: todo!(),
  461. })
  462. }
  463. #[tracing::instrument(skip_all)]
  464. async fn reflect(&self, instruction: ReflectInstruction) -> anyhow::Result<ReflectResponse> {
  465. // TODO(LIB-16). Reflect messages, then wait for corresponding reflect-acks
  466. Ok(ReflectResponse {
  467. acknowledged_reflect_ids: todo!(),
  468. })
  469. }
  470. }
  471. #[tokio::main]
  472. async fn main() -> anyhow::Result<()> {
  473. // Configure logging
  474. init_stderr_logging(Level::TRACE);
  475. // Create HTTP client
  476. let http_client = https_client_builder().build()?;
  477. // Parse arguments for command
  478. let arguments = CspE2eReceiveCommand::parse();
  479. let config = FullIdentityConfig::from_options(&http_client, arguments.config).await?;
  480. // Create CSP E2EE protocol context
  481. let mut database = InMemoryDb::from(InMemoryDbInit {
  482. user_identity: config.minimal.user_identity,
  483. settings: InMemoryDbSettings {
  484. block_unknown_identities: false,
  485. },
  486. contacts: vec![],
  487. blocked_identities: vec![],
  488. });
  489. let csp_e2e_context = CspE2eProtocolContextInit {
  490. client_info: ClientInfo::Libthreema,
  491. config: Rc::clone(&config.minimal.common.config),
  492. csp_e2e: config.csp_e2e_context_init(Box::new(RefCell::new(database.csp_e2e_nonce_provider()))),
  493. d2x: config.d2x_context_init(Box::new(RefCell::new(database.d2d_nonce_provider()))),
  494. shortcut: Box::new(DefaultShortcutProvider),
  495. settings: Box::new(RefCell::new(database.settings_provider())),
  496. contacts: Box::new(RefCell::new(database.contact_provider())),
  497. conversations: Box::new(RefCell::new(database.message_provider())),
  498. };
  499. // Create payload queues
  500. let (csp_e2e_queues, csp_queues) = {
  501. let incoming_payload = mpsc::channel(4);
  502. let outgoing_payload = mpsc::channel(4);
  503. (
  504. PayloadQueuesForCspE2e {
  505. incoming: incoming_payload.1,
  506. outgoing: outgoing_payload.0,
  507. },
  508. PayloadQueuesForCsp {
  509. incoming: incoming_payload.0,
  510. outgoing: outgoing_payload.1,
  511. },
  512. )
  513. };
  514. // Create CSP protocol and establish a connection
  515. let (mut csp_runner, client_hello) = CspProtocolRunner::new(
  516. config
  517. .minimal
  518. .common
  519. .config
  520. .chat_server_address
  521. .addresses(config.csp_server_group),
  522. config
  523. .csp_context_init()
  524. .try_into()
  525. .expect("Configuration should be valid"),
  526. )
  527. .await?;
  528. // Run the handshake flow
  529. csp_runner.run_handshake_flow(client_hello).await?;
  530. // Create CSP E2E protocol
  531. let mut csp_e2e_protocol = CspE2eProtocolRunner::new(http_client, csp_e2e_context)?;
  532. // Run the protocols
  533. tokio::select! {
  534. _ = csp_runner.run_payload_flow(csp_queues) => {},
  535. _ = csp_e2e_protocol.run_receive_flow(csp_e2e_queues) => {},
  536. _ = signal::ctrl_c() => {},
  537. };
  538. // Shut down
  539. csp_runner.shutdown().await?;
  540. Ok(())
  541. }
  542. #[test]
  543. fn verify_cli() {
  544. use clap::CommandFactory;
  545. CspE2eReceiveCommand::command().debug_assert();
  546. }