diff --git a/src/main.rs b/src/main.rs index 04f3467..eb7fdeb 100644 --- a/src/main.rs +++ b/src/main.rs @@ -1,7 +1,7 @@ use rand::{thread_rng, RngCore}; use std::env; use std::fs::remove_file; -use std::io::{BufRead, BufReader, Write}; +use std::io::{BufRead, BufReader, Read, Write}; use std::os::unix::net::{UnixListener, UnixStream}; use std::path::PathBuf; use std::process::Command; @@ -61,41 +61,46 @@ fn create_socket() -> PathBuf { temp_dir } -fn handle_client(stream: UnixStream) { - let reader = BufReader::new(stream); - +fn handle_client(mut stream: UnixStream) { let start_time = SystemTime::now(); const TIMEOUT: Duration = Duration::from_secs(3); - let handshake_msg = String::from("handshake"); - let mut lines = reader.lines(); + stream + .set_read_timeout(Some(Duration::from_millis(1))) + .unwrap(); - match lines.next() { - Some(Ok(msg)) if msg == handshake_msg => { - if SystemTime::now().duration_since(start_time).unwrap() >= TIMEOUT { - println!("Client took too long to send first message"); - return; + loop { + let mut response = String::new(); + + // Change to use buffered, read line until timeout, and catch timeout as below (hopefully can change timeout after initial read) + + match stream.read_to_string(&mut response) { + Err(e) if e.kind() == std::io::ErrorKind::WouldBlock => {} + Err(e) => { + println!("Error reading {}", e); } + _ => {} } - Some(Ok(msg)) => { - println!("First line isn't handshake: {}", msg); - return; - } - Some(Err(e)) => { - println!("Failed to get first line: {}", e); - return; - } - None => { - println!("Client terminated before first message"); - // If the stream ends here, abort + + if response.trim_end() == "handshake" { + println!("Got correct handshake"); + break; + } + + if SystemTime::now().duration_since(start_time).unwrap() >= TIMEOUT { + println!("Timeout expired, killing connection"); return; } + + // thread::sleep(Duration::from_secs(1)); } - println!("Got correct handshake"); + stream.set_read_timeout(None).unwrap(); - for line in lines { + let reader = BufReader::new(stream); + + for line in reader.lines() { dbg!(line.unwrap()); } @@ -105,11 +110,17 @@ fn handle_client(stream: UnixStream) { fn client(path: String) { let mut stream = UnixStream::connect(path).unwrap(); - writeln!(stream, "handshake").unwrap(); + thread::sleep(Duration::from_secs(2)); - thread::sleep(Duration::from_secs(1)); + writeln!(stream, "handshake").unwrap(); + writeln!(stream, "handshake1").unwrap(); + writeln!(stream, "handshake2").unwrap(); + + thread::sleep(Duration::from_secs(2)); writeln!(stream, "Hello world").unwrap(); + + thread::sleep(Duration::from_secs(2)); } fn monitor_child(socket_path: PathBuf) {