| 123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417418419420421422423424425426427428429430431432433434435436437438439440441442443444445446447448449450451452453454455456457458459460461462463464465466467468469470471472473474475476477478479480481482483484485486487488489490491492493494495496497498499500501502503504505506507508509510511512513514515516517518519520521522523524525526527528529530531532533534535536537538539540541542543544545546547548549550551552553554555556557558559560561562563564565566567568569570571572573574575576577578579580581582583584585586587588589590591592593594595596597598599600601602603604605606607608 |
//! Example usage of the Chat Server E2EE Protocol: connects to the chat server and receives incoming
//! messages.
- #![expect(unused_crate_dependencies, reason = "Example triggered false positive")]
- #![expect(
- clippy::integer_division_remainder_used,
- reason = "Some internal of tokio::select triggers this"
- )]
- #![expect(
- unreachable_code,
- unused_variables,
- clippy::todo,
- reason = "TODO(LIB-16): Finalise this, then remove me"
- )]
- use core::cell::RefCell;
- use std::{io, rc::Rc};
- use anyhow::bail;
- use clap::Parser;
- use data_encoding::HEXLOWER;
- use libthreema::{
- cli::{FullIdentityConfig, FullIdentityConfigOptions},
- common::ClientInfo,
- csp::{
- CspProtocol, CspProtocolContext, CspProtocolInstruction, CspStateUpdate,
- payload::{IncomingPayload, MessageAck, MessageWithMetadataBox, OutgoingFrame, OutgoingPayload},
- },
- csp_e2e::{
- CspE2eProtocol, CspE2eProtocolContextInit,
- contacts::{
- create::{CreateContactsInstruction, CreateContactsResponse},
- lookup::ContactsLookupResponse,
- update::{UpdateContactsInstruction, UpdateContactsResponse},
- },
- message::task::incoming::{
- IncomingMessageInstruction, IncomingMessageLoop, IncomingMessageResponse, IncomingMessageTask,
- },
- reflect::{ReflectInstruction, ReflectResponse},
- transaction::{
- begin::{BeginTransactionInstruction, BeginTransactionResponse},
- commit::{CommitTransactionInstruction, CommitTransactionResponse},
- },
- },
- https::cli::https_client_builder,
- model::provider::in_memory::{DefaultShortcutProvider, InMemoryDb, InMemoryDbInit, InMemoryDbSettings},
- utils::logging::init_stderr_logging,
- };
- use tokio::{
- io::{AsyncReadExt as _, AsyncWriteExt as _},
- net::TcpStream,
- signal,
- sync::mpsc,
- };
- use tracing::{Level, debug, error, info, trace, warn};
// Command line arguments of this example.
//
// NOTE: Plain `//` comments are used intentionally here: clap's derive turns `///` doc comments
// into the generated `--help` text, which would change the CLI output.
#[derive(Parser)]
#[command()]
struct CspE2eReceiveCommand {
    // Identity configuration options, flattened into the top-level arguments
    #[command(flatten)]
    config: FullIdentityConfigOptions,
}
- enum IncomingPayloadForCspE2e {
- Message(MessageWithMetadataBox),
- MessageAck(MessageAck),
- }
- enum OutgoingPayloadForCspE2e {
- MessageAck(MessageAck),
- }
- impl From<OutgoingPayloadForCspE2e> for OutgoingPayload {
- fn from(payload: OutgoingPayloadForCspE2e) -> Self {
- match payload {
- OutgoingPayloadForCspE2e::MessageAck(message_ack) => OutgoingPayload::MessageAck(message_ack),
- }
- }
- }
/// Payload queues for the main process, i.e. the CSP E2EE side (consumed by
/// `CspE2eProtocolRunner::run_receive_flow`).
struct PayloadQueuesForCspE2e {
    /// Payloads forwarded by the CSP connection (incoming messages and message acks)
    incoming: mpsc::Receiver<IncomingPayloadForCspE2e>,
    /// Payloads the CSP connection should send to the chat server
    outgoing: mpsc::Sender<OutgoingPayloadForCspE2e>,
}
/// Payload queues for the protocol flow runner, i.e. the CSP side (consumed by
/// `CspProtocolRunner::run_payload_flow`). Counterpart of `PayloadQueuesForCspE2e`.
struct PayloadQueuesForCsp {
    /// Sending end for payloads received from the chat server that the E2EE side has to handle
    incoming: mpsc::Sender<IncomingPayloadForCspE2e>,
    /// Receiving end for payloads the E2EE side wants to send. The payload flow stops once this
    /// queue is closed.
    outgoing: mpsc::Receiver<OutgoingPayloadForCspE2e>,
}
/// Drives the [`CspProtocol`] state machine over a TCP connection to the chat server.
struct CspProtocolRunner {
    /// The TCP stream
    stream: TcpStream,
    /// An instance of the [`CspProtocol`] state machine
    protocol: CspProtocol,
}
- impl CspProtocolRunner {
- /// Initiate a CSP protocol connection and hand out the initial `client_hello` message
- #[tracing::instrument(skip_all)]
- async fn new(
- server_address: Vec<(String, u16)>,
- context: CspProtocolContext,
- ) -> anyhow::Result<(Self, OutgoingFrame)> {
- // Connect via TCP
- debug!(?server_address, "Establishing TCP connection to chat server",);
- let tcp_stream = TcpStream::connect(
- server_address
- .first()
- .expect("CSP config should have at least one address"),
- )
- .await?;
- // Create the protocol
- let (csp_protocol, client_hello) = CspProtocol::new(context);
- Ok((
- Self {
- stream: tcp_stream,
- protocol: csp_protocol,
- },
- client_hello,
- ))
- }
- /// Do the handshake with the chat server by exchanging the following messages:
- ///
- /// ```txt
- /// C -- client-hello -> S
- /// C <- server-hello -- S
- /// C ---- login ---- -> S
- /// C <-- login-ack ---- S
- /// ```
- #[tracing::instrument(skip_all)]
- async fn run_handshake_flow(&mut self, client_hello: OutgoingFrame) -> anyhow::Result<()> {
- // Send the client hello
- debug!(length = client_hello.0.len(), "Sending client hello");
- self.send(&client_hello.0).await?;
- // Handshake by polling the CSP state
- for iteration in 1_usize.. {
- trace!("Iteration #{iteration}");
- // Receive required bytes and add them
- let bytes = self.receive_required().await?;
- self.protocol.add_chunks(&[&bytes])?;
- // Handle instruction
- let Some(instruction) = self.protocol.poll()? else {
- continue;
- };
- // We do not expect an incoming payload at this stage
- if let Some(incoming_payload) = instruction.incoming_payload {
- let message = "Unexpected incoming payload during handshake";
- error!(?incoming_payload, message);
- bail!(message)
- }
- // Send any outgoing frame
- if let Some(frame) = instruction.outgoing_frame {
- self.send(&frame.0).await?;
- }
- // Check if we've completed the handshake
- if let Some(CspStateUpdate::PostHandshake(login_ack_data)) = instruction.state_update {
- info!(?login_ack_data, "Handshake complete");
- break;
- }
- }
- Ok(())
- }
- /// Run the payload exchange flow until stopped.
- #[tracing::instrument(skip_all)]
- async fn run_payload_flow(&mut self, mut queues: PayloadQueuesForCsp) -> anyhow::Result<()> {
- let mut read_buffer = [0_u8; 8192];
- let mut next_instruction: Option<CspProtocolInstruction> = None;
- for iteration in 1_usize.. {
- trace!("Iteration #{iteration}");
- // Poll for an instruction, if necessary
- if next_instruction.is_none() {
- next_instruction = self.protocol.poll()?;
- }
- // Wait for more input, if necessary
- if next_instruction.is_none() {
- next_instruction = tokio::select! {
- // Forward any incoming chunks from the TCP stream
- _ = self.stream.readable() => {
- let length = self.try_receive(&mut read_buffer)?;
- // Add chunks (poll in the next iteration)
- self.protocol
- .add_chunks(&[read_buffer.get(..length)
- .expect("Amount of read bytes should be available")])?;
- None
- },
- // Forward any outgoing payloads
- outgoing_payload = queues.outgoing.recv() => {
- if let Some(outgoing_payload) = outgoing_payload {
- let outgoing_payload = OutgoingPayload::from(outgoing_payload);
- debug!(?outgoing_payload, "Sending payload");
- Some(self.protocol.create_payload(&outgoing_payload)?)
- } else {
- break
- }
- }
- };
- }
- // Handle instruction
- let Some(current_instruction) = next_instruction.take() else {
- continue;
- };
- // We do not expect any state updates at this stage
- if let Some(state_update) = current_instruction.state_update {
- let message = "Unexpected state update after handshake";
- error!(?state_update, message);
- bail!(message)
- }
- // Handle any incoming payload
- if let Some(incoming_payload) = current_instruction.incoming_payload {
- debug!(?incoming_payload, "Received payload");
- match incoming_payload {
- IncomingPayload::EchoRequest(echo_payload) => {
- // Respond to echo request
- next_instruction = Some(
- self.protocol
- .create_payload(&OutgoingPayload::EchoResponse(echo_payload))?,
- );
- },
- IncomingPayload::MessageWithMetadataBox(payload) => {
- // Forward message
- queues
- .incoming
- .send(IncomingPayloadForCspE2e::Message(payload))
- .await?;
- },
- IncomingPayload::MessageAck(payload) => {
- // Forward message ack
- queues
- .incoming
- .send(IncomingPayloadForCspE2e::MessageAck(payload))
- .await?;
- },
- IncomingPayload::EchoResponse(_)
- | IncomingPayload::QueueSendComplete
- | IncomingPayload::DeviceCookieChangeIndication
- | IncomingPayload::CloseError(_)
- | IncomingPayload::ServerAlert(_)
- | IncomingPayload::UnknownPayload { .. } => {},
- }
- }
- // Send any outgoing frame
- if let Some(frame) = current_instruction.outgoing_frame {
- self.send(&frame.0).await?;
- }
- }
- Ok(())
- }
- /// Shut down the TCP connection
- #[tracing::instrument(skip_all)]
- async fn shutdown(&mut self) -> anyhow::Result<()> {
- info!("Shutting down TCP connection");
- Ok(self.stream.shutdown().await?)
- }
- /// Send bytes to the server over the TCP connection
- #[tracing::instrument(skip_all, fields(bytes_length = bytes.len()))]
- async fn send(&mut self, bytes: &[u8]) -> anyhow::Result<()> {
- trace!(length = bytes.len(), "Sending bytes");
- self.stream.write_all(bytes).await?;
- Ok(())
- }
- #[tracing::instrument(skip_all)]
- async fn receive_required(&mut self) -> anyhow::Result<Vec<u8>> {
- // Get the minimum amount of bytes we'll need to receive
- let length = self.protocol.next_required_length()?;
- let mut buffer = vec![0; length];
- trace!(?length, "Reading bytes");
- // If there is nothing to read, return immediately
- if length == 0 {
- return Ok(buffer);
- }
- // Read the exact number of bytes required
- let _ = self.stream.read_exact(&mut buffer).await?;
- // Read more if available
- match self.stream.try_read_buf(&mut buffer) {
- Ok(0) => {
- // Remote shut down our reading end gracefully.
- //
- // IMPORTANT: An implementation needs to ensure that it stops gracefully by processing any
- // remaining payloads prior to stopping the protocol. This example implementation ensures this
- // by handling all pending instructions prior to polling for more data. The only case we bail
- // is therefore when our instruction queue is already dry.
- bail!("TCP reading end closed")
- },
- Ok(length) => {
- trace!(length, "Got additional bytes");
- },
- Err(error) if error.kind() == io::ErrorKind::WouldBlock => {
- trace!("No additional bytes available");
- },
- Err(error) => {
- return Err(error.into());
- },
- }
- debug!(length = buffer.len(), "Received bytes");
- Ok(buffer)
- }
- #[tracing::instrument(skip_all)]
- fn try_receive(&mut self, buffer: &mut [u8]) -> anyhow::Result<usize> {
- match self.stream.try_read(buffer) {
- Ok(0) => {
- // Remote shut down our reading end. But we still need to process the previously
- // read bytes.
- warn!("TCP reading end closed");
- Ok(0)
- },
- Ok(length) => {
- debug!(length, "Received bytes");
- Ok(length)
- },
- Err(error) if error.kind() == io::ErrorKind::WouldBlock => {
- trace!("No bytes to receive");
- Ok(0)
- },
- Err(error) => Err(error.into()),
- }
- }
- }
/// Drives the [`CspE2eProtocol`] state machine. It handles incoming messages forwarded by the CSP
/// connection and executes the resulting instructions (e.g. HTTP lookups).
struct CspE2eProtocolRunner {
    /// An instance of the [`CspE2eProtocol`] state machine
    protocol: CspE2eProtocol,
    /// HTTP client
    http_client: reqwest::Client,
}
- impl CspE2eProtocolRunner {
- #[tracing::instrument(skip_all)]
- fn new(http_client: reqwest::Client, context: CspE2eProtocolContextInit) -> anyhow::Result<Self> {
- Ok(Self {
- protocol: CspE2eProtocol::new(context),
- http_client,
- })
- }
- /// Run the receive flow until stopped.
- #[tracing::instrument(skip_all)]
- async fn run_receive_flow(&mut self, mut queues: PayloadQueuesForCspE2e) -> anyhow::Result<()> {
- let mut pending_task: Option<IncomingMessageTask> = None;
- for iteration in 1_usize.. {
- trace!("Receive flow iteration #{iteration}");
- // Handle any incoming payloads until we have a task
- let Some(task) = &mut pending_task else {
- match queues.incoming.recv().await {
- Some(IncomingPayloadForCspE2e::Message(message)) => {
- trace!(message = HEXLOWER.encode(&message.bytes), "Raw incoming message");
- info!(?message, "Incoming message");
- pending_task = Some(self.protocol.handle_incoming_message(message));
- },
- Some(IncomingPayloadForCspE2e::MessageAck(message_ack)) => {
- warn!(?message_ack, "Unexpected message-ack");
- },
- None => {},
- }
- continue;
- };
- // Handle task
- match task.poll(self.protocol.context())? {
- IncomingMessageLoop::Instruction(IncomingMessageInstruction::FetchSender(instruction)) => {
- // Run both requests simultaneously
- let work_directory_request_future = async {
- match instruction.work_directory_request {
- Some(work_directory_request) => {
- work_directory_request.send(&self.http_client).await.map(Some)
- },
- None => Ok(None),
- }
- };
- let (directory_result, work_directory_result) = tokio::join!(
- instruction.directory_request.send(&self.http_client),
- work_directory_request_future,
- );
- // Forward response
- task.response(IncomingMessageResponse::FetchSender(ContactsLookupResponse {
- directory_result,
- work_directory_result: work_directory_result.transpose(),
- }))?;
- },
- IncomingMessageLoop::Instruction(IncomingMessageInstruction::CreateContact(instruction)) => {
- match instruction {
- CreateContactsInstruction::BeginTransaction(instruction) => {
- // Begin transaction and forward response, if any
- let response = self.begin_transaction(instruction).await?;
- if let Some(response) = response {
- task.response(IncomingMessageResponse::CreateContact(
- CreateContactsResponse::BeginTransactionResponse(response),
- ))?;
- }
- },
- CreateContactsInstruction::ReflectAndCommitTransaction(instruction) => {
- // Reflect and commit transaction and forward response
- task.response(IncomingMessageResponse::CreateContact(
- CreateContactsResponse::CommitTransactionResponse(
- self.reflect_and_commit_transaction(instruction).await?,
- ),
- ))?;
- },
- }
- },
- IncomingMessageLoop::Instruction(IncomingMessageInstruction::UpdateContact(instruction)) => {
- match instruction {
- UpdateContactsInstruction::BeginTransaction(instruction) => {
- // Begin transaction and forward response, if any
- let response = self.begin_transaction(instruction).await?;
- if let Some(response) = response {
- task.response(IncomingMessageResponse::UpdateContact(
- UpdateContactsResponse::BeginTransactionResponse(response),
- ))?;
- }
- },
- UpdateContactsInstruction::ReflectAndCommitTransaction(instruction) => {
- // Reflect and commit transaction and forward response
- task.response(IncomingMessageResponse::UpdateContact(
- UpdateContactsResponse::CommitTransactionResponse(
- self.reflect_and_commit_transaction(instruction).await?,
- ),
- ))?;
- },
- }
- },
- IncomingMessageLoop::Instruction(IncomingMessageInstruction::ReflectMessage(instruction)) => {
- task.response(IncomingMessageResponse::ReflectMessage(
- self.reflect(instruction).await?,
- ))?;
- },
- IncomingMessageLoop::Done(result) => {
- // Send message acknowledgement, if any
- if let Some(outgoing_message_ack) = result.outgoing_message_ack {
- queues
- .outgoing
- .send(OutgoingPayloadForCspE2e::MessageAck(outgoing_message_ack))
- .await?;
- }
- pending_task = None;
- // TODO(LIB-16). Enqueue outgoing message task, if any
- },
- }
- }
- Ok(())
- }
- #[tracing::instrument(skip_all)]
- async fn begin_transaction(
- &self,
- instruction: BeginTransactionInstruction,
- ) -> anyhow::Result<Option<BeginTransactionResponse>> {
- match instruction {
- BeginTransactionInstruction::TransactionRejected => {
- // TODO(LIB-16). Await TransactionEnded
- Ok(None)
- },
- BeginTransactionInstruction::BeginTransaction { message } => {
- // TODO(LIB-16). Send `BeginTransaction, await BeginTransactionAck or TransactionRejected,
- // then return BeginTransactionResponse(message)
- Ok(Some(BeginTransactionResponse::BeginTransactionReply(todo!())))
- },
- BeginTransactionInstruction::AbortTransaction { message } => {
- // TODO(LIB-16). Send `CommitTransaction`, await CommitTransactionAck, then return
- // AbortTransaction(CommitTransactionAck)
- Ok(Some(BeginTransactionResponse::AbortTransactionResponse(todo!())))
- },
- }
- }
- #[tracing::instrument(skip_all)]
- async fn reflect_and_commit_transaction(
- &self,
- instruction: CommitTransactionInstruction,
- ) -> anyhow::Result<CommitTransactionResponse> {
- // TODO(LIB-16). Reflect messages, then immediately commit. Await CommitAck and gather any
- // reflect-acks
- Ok(CommitTransactionResponse {
- acknowledged_reflect_ids: todo!(),
- commit_transaction_ack: todo!(),
- })
- }
- #[tracing::instrument(skip_all)]
- async fn reflect(&self, instruction: ReflectInstruction) -> anyhow::Result<ReflectResponse> {
- // TODO(LIB-16). Reflect messages, then wait for corresponding reflect-acks
- Ok(ReflectResponse {
- acknowledged_reflect_ids: todo!(),
- })
- }
- }
- #[tokio::main]
- async fn main() -> anyhow::Result<()> {
- // Configure logging
- init_stderr_logging(Level::TRACE);
- // Create HTTP client
- let http_client = https_client_builder().build()?;
- // Parse arguments for command
- let arguments = CspE2eReceiveCommand::parse();
- let config = FullIdentityConfig::from_options(&http_client, arguments.config).await?;
- // Create CSP E2EE protocol context
- let mut database = InMemoryDb::from(InMemoryDbInit {
- user_identity: config.minimal.user_identity,
- settings: InMemoryDbSettings {
- block_unknown_identities: false,
- },
- contacts: vec![],
- blocked_identities: vec![],
- });
- let csp_e2e_context = CspE2eProtocolContextInit {
- client_info: ClientInfo::Libthreema,
- config: Rc::clone(&config.minimal.common.config),
- csp_e2e: config.csp_e2e_context_init(Box::new(RefCell::new(database.csp_e2e_nonce_provider()))),
- d2x: config.d2x_context_init(Box::new(RefCell::new(database.d2d_nonce_provider()))),
- shortcut: Box::new(DefaultShortcutProvider),
- settings: Box::new(RefCell::new(database.settings_provider())),
- contacts: Box::new(RefCell::new(database.contact_provider())),
- conversations: Box::new(RefCell::new(database.message_provider())),
- };
- // Create payload queues
- let (csp_e2e_queues, csp_queues) = {
- let incoming_payload = mpsc::channel(4);
- let outgoing_payload = mpsc::channel(4);
- (
- PayloadQueuesForCspE2e {
- incoming: incoming_payload.1,
- outgoing: outgoing_payload.0,
- },
- PayloadQueuesForCsp {
- incoming: incoming_payload.0,
- outgoing: outgoing_payload.1,
- },
- )
- };
- // Create CSP protocol and establish a connection
- let (mut csp_runner, client_hello) = CspProtocolRunner::new(
- config
- .minimal
- .common
- .config
- .chat_server_address
- .addresses(config.csp_server_group),
- config
- .csp_context_init()
- .try_into()
- .expect("Configuration should be valid"),
- )
- .await?;
- // Run the handshake flow
- csp_runner.run_handshake_flow(client_hello).await?;
- // Create CSP E2E protocol
- let mut csp_e2e_protocol = CspE2eProtocolRunner::new(http_client, csp_e2e_context)?;
- // Run the protocols
- tokio::select! {
- _ = csp_runner.run_payload_flow(csp_queues) => {},
- _ = csp_e2e_protocol.run_receive_flow(csp_e2e_queues) => {},
- _ = signal::ctrl_c() => {},
- };
- // Shut down
- csp_runner.shutdown().await?;
- Ok(())
- }
/// Sanity-check the clap definition of `CspE2eReceiveCommand` using clap's built-in debug
/// assertions (conflicting argument names, invalid defaults, ...).
#[test]
fn verify_cli() {
    let command = <CspE2eReceiveCommand as clap::CommandFactory>::command();
    command.debug_assert();
}
|