1
Fork 0

Improve handshake implementation a little

This commit is contained in:
Jake Howard 2023-08-20 10:45:49 +01:00
parent 59e0c934e2
commit 3369862d8d
Signed by: jake
GPG key ID: 57AFB45680EDD477

View file

@ -1,7 +1,7 @@
use rand::{thread_rng, RngCore}; use rand::{thread_rng, RngCore};
use std::env; use std::env;
use std::fs::remove_file; use std::fs::remove_file;
use std::io::{BufRead, BufReader, Write}; use std::io::{BufRead, BufReader, Read, Write};
use std::os::unix::net::{UnixListener, UnixStream}; use std::os::unix::net::{UnixListener, UnixStream};
use std::path::PathBuf; use std::path::PathBuf;
use std::process::Command; use std::process::Command;
@ -61,41 +61,46 @@ fn create_socket() -> PathBuf {
temp_dir temp_dir
} }
fn handle_client(stream: UnixStream) { fn handle_client(mut stream: UnixStream) {
let reader = BufReader::new(stream);
let start_time = SystemTime::now(); let start_time = SystemTime::now();
const TIMEOUT: Duration = Duration::from_secs(3); const TIMEOUT: Duration = Duration::from_secs(3);
let handshake_msg = String::from("handshake");
let mut lines = reader.lines(); stream
.set_read_timeout(Some(Duration::from_millis(1)))
.unwrap();
match lines.next() { loop {
Some(Ok(msg)) if msg == handshake_msg => { let mut response = String::new();
if SystemTime::now().duration_since(start_time).unwrap() >= TIMEOUT {
println!("Client took too long to send first message"); // Change to use buffered, read line until timeout, and catch timeout as below (hopefully can change timeout after initial read)
return;
} match stream.read_to_string(&mut response) {
} Err(e) if e.kind() == std::io::ErrorKind::WouldBlock => {}
Some(Ok(msg)) => { Err(e) => {
println!("First line isn't handshake: {}", msg); println!("Error reading {}", e);
return;
}
Some(Err(e)) => {
println!("Failed to get first line: {}", e);
return;
}
None => {
println!("Client terminated before first message");
// If the stream ends here, abort
return;
} }
_ => {}
} }
if response.trim_end() == "handshake" {
println!("Got correct handshake"); println!("Got correct handshake");
break;
}
for line in lines { if SystemTime::now().duration_since(start_time).unwrap() >= TIMEOUT {
println!("Timeout expired, killing connection");
return;
}
// thread::sleep(Duration::from_secs(1));
}
stream.set_read_timeout(None).unwrap();
let reader = BufReader::new(stream);
for line in reader.lines() {
dbg!(line.unwrap()); dbg!(line.unwrap());
} }
@ -105,11 +110,17 @@ fn handle_client(stream: UnixStream) {
fn client(path: String) { fn client(path: String) {
let mut stream = UnixStream::connect(path).unwrap(); let mut stream = UnixStream::connect(path).unwrap();
writeln!(stream, "handshake").unwrap(); thread::sleep(Duration::from_secs(2));
thread::sleep(Duration::from_secs(1)); writeln!(stream, "handshake").unwrap();
writeln!(stream, "handshake1").unwrap();
writeln!(stream, "handshake2").unwrap();
thread::sleep(Duration::from_secs(2));
writeln!(stream, "Hello world").unwrap(); writeln!(stream, "Hello world").unwrap();
thread::sleep(Duration::from_secs(2));
} }
fn monitor_child(socket_path: PathBuf) { fn monitor_child(socket_path: PathBuf) {