use axum::extract::ws::{Message, WebSocket}; use axum::extract::{Query, State, WebSocketUpgrade}; use axum::response::IntoResponse; use futures_util::{SinkExt, StreamExt}; use serde::Deserialize; use tokio::sync::mpsc; use ericrfb::framebuffer::Framebuffer; use ericrfb::handshake::Config; use ericrfb::input; use ericrfb::msg; use ericrfb::proto::RGB332_LUT; use ericrfb::session::{ActiveSession, Event}; use crate::AppState; // --------------------------------------------------------------------------- // WS binary protocol tags // --------------------------------------------------------------------------- // Proxy → Browser const TAG_BLIT: u8 = 0x01; const TAG_RESIZE: u8 = 0x03; // Browser → Proxy const TAG_KEY_PRESS: u8 = 0x10; const TAG_KEY_RELEASE: u8 = 0x11; const TAG_POINTER: u8 = 0x12; const TAG_CTRL_ALT_DEL: u8 = 0x13; #[derive(Deserialize)] pub struct WsQuery { pub applet_id: String, #[serde(default = "default_port")] pub port: u16, } fn default_port() -> u16 { 443 } pub async fn handle_ws( ws: WebSocketUpgrade, State(state): State, Query(query): Query, ) -> impl IntoResponse { ws.on_upgrade(move |socket| run_session(socket, state, query)) } async fn run_session(socket: WebSocket, state: AppState, query: WsQuery) { let cfg = Config::new(&state.config.omniview.host, query.port, &query.applet_id); tracing::info!( "WS session starting: {}:{}", state.config.omniview.host, query.port ); // Connect to OmniView in a blocking task (handshake is sync IO) let session = match tokio::task::spawn_blocking(move || { ActiveSession::connect(&cfg, &[7, 5, 1, 0, -250]) }) .await { Ok(Ok(s)) => s, Ok(Err(e)) => { tracing::error!("OmniView connect failed: {e}"); return; } Err(e) => { tracing::error!("spawn_blocking panicked: {e}"); return; } }; tracing::info!( "Connected to OmniView: {}x{}", session.framebuffer.width, session.framebuffer.height ); let (ws_tx, ws_rx) = socket.split(); let (blit_tx, blit_rx) = mpsc::channel::(64); // Channel for input events from browser → OmniView writer let (input_tx, input_rx) = mpsc::channel::(64); // Task: forward blit messages to WebSocket let ws_send_task = tokio::spawn(forward_ws_send(ws_tx, blit_rx)); // Task: receive input from WebSocket let ws_recv_task = tokio::spawn(forward_ws_recv(ws_rx, input_tx)); // Task: OmniView session pump (blocking) let pump_task = tokio::task::spawn_blocking(move || run_pump(session, blit_tx, input_rx)); // Wait for any task to finish (on error or disconnect) tokio::select! { r = ws_send_task => { tracing::debug!("ws_send finished: {r:?}"); } r = ws_recv_task => { tracing::debug!("ws_recv finished: {r:?}"); } r = pump_task => { tracing::debug!("pump finished: {r:?}"); } } tracing::info!("WS session ended"); } // --------------------------------------------------------------------------- // OmniView pump (runs on blocking thread) // --------------------------------------------------------------------------- enum InputEvent { KeyPress(u8), KeyRelease(u8), Pointer { x: u16, y: u16, mask: u8 }, CtrlAltDel, } fn run_pump( mut session: ActiveSession, blit_tx: mpsc::Sender, mut input_rx: mpsc::Receiver, ) { // Send initial resize message let w = session.framebuffer.width; let h = session.framebuffer.height; let _ = blit_tx.blocking_send(make_resize_msg(w, h)); loop { // Drain any pending input events while let Ok(evt) = input_rx.try_recv() { if let Err(e) = handle_input(&mut session, evt) { tracing::error!("input error: {e}"); return; } } // Process one server message match session.process_one() { Ok(Some(Event::FramebufferDirty)) => { // Send full framebuffer as RGBA blit let msg = make_full_blit(&session.framebuffer); if blit_tx.blocking_send(msg).is_err() { return; // WS closed } // Request next update if let Err(e) = session.request_update() { tracing::error!("request_update error: {e}"); return; } } Ok(Some(Event::Resize { width, height })) => { let _ = blit_tx.blocking_send(make_resize_msg(width, height)); } Ok(_) => {} Err(e) => { tracing::error!("session error: {e}"); return; } } } } fn handle_input(session: &mut ActiveSession, evt: InputEvent) -> Result<(), String> { match evt { InputEvent::KeyPress(sc) => { input::write_key_press(&mut session.writer, sc).map_err(|e| e.to_string()) } InputEvent::KeyRelease(sc) => { input::write_key_release(&mut session.writer, sc).map_err(|e| e.to_string()) } InputEvent::Pointer { x, y, mask } => { msg::write_pointer_event(&mut session.writer, x, y, mask).map_err(|e| e.to_string()) } InputEvent::CtrlAltDel => { input::write_ctrl_alt_del(&mut session.writer).map_err(|e| e.to_string()) } } } // --------------------------------------------------------------------------- // Binary message builders // --------------------------------------------------------------------------- fn make_full_blit(fb: &Framebuffer) -> Message { let w = fb.width; let h = fb.height; // Header: tag(1) + x(2) + y(2) + w(2) + h(2) = 9 bytes let mut buf = Vec::with_capacity(9 + (w as usize * h as usize * 4)); buf.push(TAG_BLIT); buf.extend_from_slice(&0u16.to_be_bytes()); // x buf.extend_from_slice(&0u16.to_be_bytes()); // y buf.extend_from_slice(&w.to_be_bytes()); buf.extend_from_slice(&h.to_be_bytes()); // RGBA pixels for &px in &fb.pixels { buf.extend_from_slice(&RGB332_LUT[px as usize]); } Message::Binary(buf) } fn make_resize_msg(w: u16, h: u16) -> Message { let mut buf = Vec::with_capacity(5); buf.push(TAG_RESIZE); buf.extend_from_slice(&w.to_be_bytes()); buf.extend_from_slice(&h.to_be_bytes()); Message::Binary(buf) } // --------------------------------------------------------------------------- // WebSocket forwarding tasks // --------------------------------------------------------------------------- async fn forward_ws_send( mut tx: futures_util::stream::SplitSink, mut rx: mpsc::Receiver, ) { while let Some(msg) = rx.recv().await { if tx.send(msg).await.is_err() { break; } } } async fn forward_ws_recv( mut rx: futures_util::stream::SplitStream, tx: mpsc::Sender, ) { while let Some(Ok(msg)) = rx.next().await { match msg { Message::Binary(data) if !data.is_empty() => { if let Some(evt) = parse_input(&data) && tx.send(evt).await.is_err() { break; } } Message::Close(_) => break, _ => {} } } } fn parse_input(data: &[u8]) -> Option { match data[0] { TAG_KEY_PRESS if data.len() >= 2 => Some(InputEvent::KeyPress(data[1])), TAG_KEY_RELEASE if data.len() >= 2 => Some(InputEvent::KeyRelease(data[1])), TAG_POINTER if data.len() >= 6 => { let x = u16::from_be_bytes([data[1], data[2]]); let y = u16::from_be_bytes([data[3], data[4]]); let mask = data[5]; Some(InputEvent::Pointer { x, y, mask }) } TAG_CTRL_ALT_DEL => Some(InputEvent::CtrlAltDel), _ => None, } }