d2m_ping_pong.rs 14 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417418419420421422423424425426427428429430431432433434435
  1. //! Example for usage of the Device to Mediator Protocol state machine, doing a real handshake with the
  2. //! mediator server and an exemplary payload flow loop.
  3. #![expect(unused_crate_dependencies, reason = "Example triggered false positive")]
  4. #![expect(
  5. clippy::integer_division_remainder_used,
  6. reason = "Some internal of tokio::select triggers this"
  7. )]
  8. use core::time::Duration;
  9. use anyhow::{Result, anyhow, bail};
  10. use clap::Parser;
  11. use futures_util::{SinkExt as _, TryStreamExt as _};
  12. use libthreema::{
  13. cli::{FullIdentityConfig, FullIdentityConfigOptions},
  14. d2m::{
  15. D2mContext, D2mProtocol, D2mStateUpdate,
  16. payload::{BeginTransaction, IncomingPayload, OutgoingPayload, Reflect, ReflectFlags},
  17. },
  18. https::cli::https_client_builder,
  19. utils::logging::init_stderr_logging,
  20. };
  21. use rand::random;
  22. use reqwest::StatusCode;
  23. use tokio::{
  24. net::TcpStream,
  25. signal,
  26. sync::mpsc,
  27. time::{self, Instant},
  28. };
  29. use tokio_tungstenite::{
  30. MaybeTlsStream, WebSocketStream, connect_async,
  31. tungstenite::protocol::{CloseFrame, Message, frame::coding::CloseCode},
  32. };
  33. use tracing::{Level, debug, error, info, trace, warn};
  34. #[derive(Parser)]
  35. #[command()]
  36. struct D2mPingPongCommand {
  37. #[command(flatten)]
  38. config: FullIdentityConfigOptions,
  39. }
  40. /// Payload queues for the main process
  41. struct PayloadQueuesForD2mPingPong {
  42. incoming: mpsc::Receiver<IncomingPayload>,
  43. outgoing: mpsc::Sender<OutgoingPayload>,
  44. }
  45. /// Payload queues for the protocol flow runner
  46. struct PayloadQueuesForProtocol {
  47. incoming: mpsc::Sender<IncomingPayload>,
  48. outgoing: mpsc::Receiver<OutgoingPayload>,
  49. }
  50. struct D2mProtocolRunner {
  51. /// The WebSocket stream
  52. stream: WebSocketStream<MaybeTlsStream<TcpStream>>,
  53. /// An instance of the [`D2mProtocol`] state machine
  54. protocol: D2mProtocol,
  55. }
  56. impl D2mProtocolRunner {
  57. /// Initiate a D2M protocol connection
  58. #[tracing::instrument(skip_all)]
  59. async fn new(context: D2mContext) -> Result<Self> {
  60. // Create the protocol
  61. let (d2m_protocol, url) = D2mProtocol::new(context);
  62. // Connect via WebSocket
  63. debug!(?url, "Establishing WebSocket connection to mediator server");
  64. let (stream, response) = connect_async(url).await?;
  65. if response.status() != StatusCode::SWITCHING_PROTOCOLS {
  66. bail!(
  67. "Expected response to switch protocols ({expected}), got {actual}",
  68. expected = StatusCode::SWITCHING_PROTOCOLS,
  69. actual = response.status(),
  70. );
  71. }
  72. Ok(Self {
  73. stream,
  74. protocol: d2m_protocol,
  75. })
  76. }
  77. /// Do the handshake with the mediator server by exchanging the following messages:
  78. ///
  79. /// ```txt
  80. /// C -- client-info --> S (was already sent as part of the URL's path)
  81. /// C <- server-hello -- S
  82. /// C -- client-hello -> S
  83. /// C <- server-info --- S
  84. /// ```
  85. async fn run_handshake_flow(&mut self) -> Result<()> {
  86. for iteration in 1_usize.. {
  87. trace!("Iteration #{iteration}");
  88. // Receive datagram and add it
  89. let datagram = self.receive().await?;
  90. self.protocol.add_datagrams(vec![datagram])?;
  91. // Handle instruction
  92. let Some(instruction) = self.protocol.poll()? else {
  93. continue;
  94. };
  95. // We do not expect an incoming payload at this stage
  96. if let Some(incoming_payload) = instruction.incoming_payload {
  97. let message = "Unexpected incoming payload during handshake";
  98. error!(?incoming_payload, message);
  99. bail!(message)
  100. }
  101. // Send any outgoing datagram
  102. if let Some(datagram) = instruction.outgoing_datagram {
  103. self.send(datagram.0).await?;
  104. }
  105. // Check if we've completed the handshake
  106. if let Some(D2mStateUpdate::PostHandshake(server_info)) = instruction.state_update {
  107. info!(?server_info, "Handshake completed");
  108. break;
  109. }
  110. }
  111. Ok(())
  112. }
  113. /// Run the payload exchange flow until stopped.
  114. #[tracing::instrument(skip_all)]
  115. async fn run_payload_flow(&mut self, mut queues: PayloadQueuesForProtocol) -> Result<()> {
  116. for iteration in 1_usize.. {
  117. trace!("Payload flow iteration #{iteration}");
  118. // Poll for any pending instruction
  119. let mut instruction = self.protocol.poll()?;
  120. if instruction.is_none() {
  121. // No pending instruction left, wait for more input
  122. instruction = tokio::select! {
  123. // Forward any incoming datagrams from the WebSocket transport
  124. datagram = self.receive() => {
  125. // Add datagram (poll in the next iteration)
  126. self.protocol.add_datagrams(vec![datagram?])?;
  127. None
  128. },
  129. // Forward any outgoing payloads
  130. Some(outgoing_payload) = queues.outgoing.recv() => {
  131. debug!(?outgoing_payload, "Sending payload");
  132. let instruction = self.protocol.create_payload(outgoing_payload)?;
  133. Some(instruction)
  134. }
  135. }
  136. }
  137. let Some(instruction) = instruction else {
  138. continue;
  139. };
  140. // We do not expect any state updates at this stage
  141. if let Some(state_update) = instruction.state_update {
  142. let message = "Unexpected state update after handshake";
  143. error!(?state_update, message);
  144. bail!(message)
  145. }
  146. // Log any incoming payload
  147. if let Some(incoming_payload) = instruction.incoming_payload {
  148. debug!(?incoming_payload, "Received payload");
  149. queues.incoming.send(incoming_payload).await?;
  150. }
  151. // Send any outgoing datagram
  152. if let Some(datagram) = instruction.outgoing_datagram {
  153. self.send(datagram.0).await?;
  154. }
  155. }
  156. Ok(())
  157. }
  158. #[tracing::instrument(skip_all)]
  159. async fn shutdown(mut self) -> Result<()> {
  160. info!("Shutting down WebSocket connection");
  161. // Normal closure, e.g. when the user is explicitly disconnecting
  162. Ok(self
  163. .stream
  164. .close(Some(CloseFrame {
  165. code: CloseCode::Normal,
  166. reason: "Bye".into(),
  167. }))
  168. .await?)
  169. }
  170. #[tracing::instrument(skip_all)]
  171. async fn send(&mut self, datagram: Vec<u8>) -> Result<()> {
  172. trace!(length = datagram.len(), "Sending datagram");
  173. self.stream.send(Message::Binary(datagram.into())).await?;
  174. Ok(())
  175. }
  176. #[tracing::instrument(skip_all)]
  177. async fn receive(&mut self) -> Result<Vec<u8>> {
  178. let datagram = loop {
  179. let message = self
  180. .stream
  181. .try_next()
  182. .await?
  183. .ok_or(anyhow!("WebSocket reading end closed"))?;
  184. match message {
  185. Message::Binary(bytes) => break bytes.to_vec(),
  186. Message::Text(text) => {
  187. bail!("Received unexpected text message: {}", text.as_str())
  188. },
  189. Message::Ping(bytes) => {
  190. // WARNING: There's a slight chance that the pong is lost when this is cancelled!
  191. debug!(ping_length = bytes.len(), "Received ping, responding with a pong");
  192. self.stream.feed(Message::Pong(bytes)).await?;
  193. debug!("Pong sent");
  194. },
  195. Message::Pong(bytes) => {
  196. debug!(pong_length = bytes.len(), "Received pong");
  197. },
  198. Message::Close(close_frame) => {
  199. info!(?close_frame, "Received close");
  200. },
  201. Message::Frame(_) => {
  202. bail!("Received unexpected raw frame");
  203. },
  204. }
  205. };
  206. debug!(datagram_length = datagram.len(), "Received datagram");
  207. Ok(datagram)
  208. }
  209. }
  /// Where the ping-pong flow currently is with respect to a D2M transaction.
  #[derive(PartialEq, Debug)]
  enum TransactionState {
      /// No transaction in progress.
      None,
      /// A begin request was rejected; waiting for the other transaction to end before retrying.
      Blocked,
      /// A begin request was sent; waiting for its acknowledgement.
      AwaitingBeginAck,
      /// The begin request was acknowledged; the transaction is open.
      Running,
      /// A commit was sent; waiting for its acknowledgement.
      AwaitingCommitAck,
  }
  218. struct D2mPingPongFlowRunner {
  219. queues: PayloadQueuesForD2mPingPong,
  220. transaction_state: TransactionState,
  221. reflect_id_counter: u32,
  222. }
  223. impl D2mPingPongFlowRunner {
  224. fn new(queues: PayloadQueuesForD2mPingPong) -> Self {
  225. Self {
  226. queues,
  227. transaction_state: TransactionState::None,
  228. reflect_id_counter: 0,
  229. }
  230. }
  231. async fn run(mut self) -> Result<()> {
  232. // Create a timer that will periodically trigger an outgoing payload
  233. let mut payload_timer = time::interval_at(
  234. Instant::now()
  235. .checked_add(Duration::from_secs(10))
  236. .expect("Oops, apocalypse is near"),
  237. Duration::from_secs(10),
  238. );
  239. // Enter ping-pong flow loop
  240. loop {
  241. let outgoing_payload = tokio::select! {
  242. // Create an outgoing payload when the timer fires
  243. _ = payload_timer.tick() => {
  244. self.create_outgoing_payload()
  245. },
  246. // Process incoming payload
  247. incoming_payload = self.queues.incoming.recv() => {
  248. if let Some(incoming_payload) = incoming_payload {
  249. info!(?incoming_payload, "Received payload");
  250. self.handle_incoming_payload(&incoming_payload)
  251. } else {
  252. break
  253. }
  254. }
  255. };
  256. // Send any outgoing payload
  257. if let Some(outgoing_payload) = outgoing_payload {
  258. info!(?outgoing_payload, "Sending payload");
  259. self.queues.outgoing.send(outgoing_payload).await?;
  260. }
  261. }
  262. Ok(())
  263. }
  264. #[tracing::instrument(skip_all)]
  265. fn handle_incoming_payload(&mut self, incoming_payload: &IncomingPayload) -> Option<OutgoingPayload> {
  266. match incoming_payload {
  267. // Transaction acknowledged: Now running
  268. IncomingPayload::BeginTransactionAck => {
  269. self.transaction_state = TransactionState::Running;
  270. None
  271. },
  272. // Transaction committed: Now none ongoing
  273. IncomingPayload::CommitTransactionAck => {
  274. self.transaction_state = TransactionState::None;
  275. None
  276. },
  277. // Transaction rejected: Retry beginning once we're unblocked
  278. IncomingPayload::TransactionRejected(_) => {
  279. self.transaction_state = TransactionState::Blocked;
  280. None
  281. },
  282. // Another transaction ended: Retry if we were blocked
  283. IncomingPayload::TransactionEnded(_) => {
  284. if self.transaction_state == TransactionState::Blocked {
  285. self.transaction_state = TransactionState::None;
  286. Some(OutgoingPayload::BeginTransaction(BeginTransaction {
  287. encrypted_scope: b"encrypted_scope".to_vec(),
  288. ttl: None,
  289. }))
  290. } else {
  291. None
  292. }
  293. },
  294. _ => None,
  295. }
  296. }
  297. #[tracing::instrument(skip_all)]
  298. fn create_outgoing_payload(&mut self) -> Option<OutgoingPayload> {
  299. trace!(state = ?self.transaction_state);
  300. match self.transaction_state {
  301. // No transaction: Occasionally begin a transaction
  302. TransactionState::None if random::<bool>() => {
  303. self.transaction_state = TransactionState::AwaitingBeginAck;
  304. Some(OutgoingPayload::BeginTransaction(BeginTransaction {
  305. encrypted_scope: b"encrypted_scope".to_vec(),
  306. ttl: None,
  307. }))
  308. },
  309. // Transaction running: Commit
  310. TransactionState::Running => {
  311. self.transaction_state = TransactionState::AwaitingCommitAck;
  312. Some(OutgoingPayload::CommitTransaction)
  313. },
  314. // No trannsaction running: Reflect
  315. TransactionState::None
  316. | TransactionState::Blocked
  317. | TransactionState::AwaitingBeginAck
  318. | TransactionState::AwaitingCommitAck => {
  319. self.reflect_id_counter = self.reflect_id_counter.checked_add(1)?;
  320. Some(OutgoingPayload::Reflect(Reflect {
  321. flags: ReflectFlags(ReflectFlags::EPHEMERAL_MARKER),
  322. reflect_id: self.reflect_id_counter,
  323. envelope: b"envelope".to_vec(),
  324. }))
  325. },
  326. }
  327. }
  328. }
  329. #[tokio::main]
  330. async fn main() -> Result<()> {
  331. // Configure logging
  332. init_stderr_logging(Level::TRACE);
  333. // Create HTTP client
  334. let http_client = https_client_builder().build()?;
  335. // Parse command
  336. let arguments = D2mPingPongCommand::parse();
  337. let config = FullIdentityConfig::from_options(&http_client, arguments.config).await?;
  338. // Create payload queues
  339. let (app_queues, protocol_queues) = {
  340. let incoming_payload = mpsc::channel(4);
  341. let outgoing_payload = mpsc::channel(4);
  342. (
  343. PayloadQueuesForD2mPingPong {
  344. incoming: incoming_payload.1,
  345. outgoing: outgoing_payload.0,
  346. },
  347. PayloadQueuesForProtocol {
  348. incoming: incoming_payload.0,
  349. outgoing: outgoing_payload.1,
  350. },
  351. )
  352. };
  353. // Create D2M protocol and establish a connection
  354. let mut d2m_connection = D2mProtocolRunner::new(
  355. config
  356. .d2m_context()
  357. .expect("Configuration must include D2X configuration"),
  358. )
  359. .await?;
  360. // Create protocol flow runner
  361. let ping_pong_flow_runner = D2mPingPongFlowRunner::new(app_queues);
  362. // Run the handshake flow
  363. d2m_connection.run_handshake_flow().await?;
  364. // Run the protocols
  365. tokio::select! {
  366. _ = d2m_connection.run_payload_flow(protocol_queues) => {}
  367. _ = ping_pong_flow_runner.run() => {}
  368. _ = signal::ctrl_c() => {},
  369. };
  370. // Shut down
  371. d2m_connection.shutdown().await?;
  372. Ok(())
  373. }
  /// Let clap validate the derived command definition (conflicting names, invalid settings, ...).
  #[test]
  fn verify_cli() {
      let command = <D2mPingPongCommand as clap::CommandFactory>::command();
      command.debug_assert();
  }