use futures_util::{SinkExt, StreamExt}; use tokio::sync::mpsc; use tokio_tungstenite::{connect_async, tungstenite::Message}; use tracing::{error, info, warn}; use crate::config::Config; use crate::messages::{parse_socketio, AgentMessage, ServerMessage}; use crate::pty_manager::PtyManager; pub async fn run_agent(config: Config) { let pty_manager = PtyManager::new(); let mut retry_delay = 1u64; loop { info!("Connecting to {}...", config.server_url); match connect_and_run(&config, &pty_manager).await { Ok(()) => { info!("Connection closed normally"); retry_delay = 1; } Err(e) => { error!("Connection error: {}", e); } } let delay = retry_delay.min(60); warn!("Reconnecting in {}s...", delay); tokio::time::sleep(tokio::time::Duration::from_secs(delay)).await; retry_delay = (retry_delay * 2).min(60); } } async fn connect_and_run(config: &Config, pty_manager: &PtyManager) -> Result<(), String> { // Socket.IO handshake: first GET /socket.io/?EIO=4&transport=polling // Then upgrade to WebSocket with /socket.io/?EIO=4&transport=websocket let ws_url = format!( "{}/socket.io/?EIO=4&transport=websocket", config.server_url.replace("ws://", "ws://").replace("wss://", "wss://") ); let (ws_stream, _) = connect_async(&ws_url) .await .map_err(|e| format!("WebSocket connect failed: {}", e))?; info!("Connected to server"); let (mut write, mut read) = ws_stream.split(); // Socket.IO handshake: send "40" to connect to namespace write .send(Message::Text("40/ws,".to_string())) .await .map_err(|e| format!("Handshake failed: {}", e))?; // Send register let register = AgentMessage::Register { token: config.token.clone(), hostname: config.name.clone(), os: std::env::consts::OS.to_string(), }; // Wait for namespace connect confirmation // Then send register event on /ws namespace let register_msg = format!("42/ws,{}", ®ister.to_socketio()[2..]); write .send(Message::Text(register_msg)) .await .map_err(|e| format!("Register failed: {}", e))?; info!("Registered as {}", config.name); // Channel for PTY output let (output_tx, mut output_rx) = mpsc::unbounded_channel::<(String, String)>(); // Heartbeat task let heartbeat_interval = config.heartbeat_interval; let (heartbeat_tx, mut heartbeat_rx) = mpsc::unbounded_channel::<()>(); tokio::spawn(async move { loop { tokio::time::sleep(tokio::time::Duration::from_secs(heartbeat_interval)).await; if heartbeat_tx.send(()).is_err() { break; } } }); loop { tokio::select! { // Receive from server msg = read.next() => { match msg { Some(Ok(Message::Text(text))) => { // Handle Socket.IO ping/pong if text == "2" { let _ = write.send(Message::Text("3".to_string())).await; continue; } // Parse /ws namespace messages let clean = if text.starts_with("42/ws,") { format!("42{}", &text[6..]) } else { text.clone() }; if let Some(server_msg) = parse_socketio(&clean) { match server_msg { ServerMessage::Registered { machine_id } => { info!("Registered with machine_id: {}", machine_id); } ServerMessage::SessionStart { session_id, command } => { info!("Starting session {}: {}", session_id, command); let tx = output_tx.clone(); match pty_manager.start_session(session_id.clone(), &command, tx).await { Ok(()) => { let status_msg = AgentMessage::SessionStatus { session_id, status: "running".to_string(), }; let msg = format!("42/ws,{}", &status_msg.to_socketio()[2..]); let _ = write.send(Message::Text(msg)).await; } Err(e) => error!("Failed to start session: {}", e), } } ServerMessage::SessionStop { session_id } => { info!("Stopping session {}", session_id); let _ = pty_manager.stop_session(&session_id).await; let status_msg = AgentMessage::SessionStatus { session_id, status: "stopped".to_string(), }; let msg = format!("42/ws,{}", &status_msg.to_socketio()[2..]); let _ = write.send(Message::Text(msg)).await; } ServerMessage::SessionInput { session_id, input } => { if let Err(e) = pty_manager.write_to_session(&session_id, &input).await { error!("Failed to write to session: {}", e); } } ServerMessage::Error { message } => { error!("Server error: {}", message); return Err(message); } } } } Some(Ok(Message::Close(_))) | None => { return Ok(()); } Some(Err(e)) => { return Err(format!("WebSocket error: {}", e)); } _ => {} } } // PTY output -> send to server Some((session_id, output)) = output_rx.recv() => { if output.is_empty() { // Session ended let status_msg = AgentMessage::SessionStatus { session_id, status: "stopped".to_string(), }; let msg = format!("42/ws,{}", &status_msg.to_socketio()[2..]); let _ = write.send(Message::Text(msg)).await; } else { let output_msg = AgentMessage::SessionOutput { session_id, output, }; let msg = format!("42/ws,{}", &output_msg.to_socketio()[2..]); let _ = write.send(Message::Text(msg)).await; } } // Heartbeat Some(()) = heartbeat_rx.recv() => { let hb = AgentMessage::Heartbeat { timestamp: format!("{}", std::time::SystemTime::now() .duration_since(std::time::UNIX_EPOCH) .unwrap_or_default() .as_secs()), }; let msg = format!("42/ws,{}", &hb.to_socketio()[2..]); let _ = write.send(Message::Text(msg)).await; } } } }