diff --git a/Cargo.lock b/Cargo.lock index 9cfdad9..6a0eb94 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -379,6 +379,12 @@ dependencies = [ "subtle", ] +[[package]] +name = "dotenv" +version = "0.15.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "77c90badedccf4105eca100756a0b1289e191f6fcbdadd3cee1d2f614f97da8f" + [[package]] name = "ecdsa" version = "0.16.8" @@ -609,6 +615,7 @@ name = "irc-e2e" version = "0.1.0" dependencies = [ "base64", + "dotenv", "eyre", "ircparser", "openssl", diff --git a/Cargo.toml b/Cargo.toml index 9ae21de..2ea3f39 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -7,6 +7,7 @@ edition = "2021" [dependencies] base64 = "0.21.4" +dotenv = "0.15.0" eyre = "0.6.8" ircparser = "0.2.1" openssl = "0.10" diff --git a/src/helpers.rs b/src/helpers.rs index 1c3915e..b9976bd 100644 --- a/src/helpers.rs +++ b/src/helpers.rs @@ -3,18 +3,52 @@ use eyre::Result; use ircparser; use std::sync::mpsc::{self, Receiver, Sender}; +#[derive(Debug)] +struct IrcParseError; + +impl std::fmt::Display for IrcParseError { + fn fmt(&self, _: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { + Ok(()) + } +} + +impl std::error::Error for IrcParseError {} + static MAX_LENGTH: usize = 300; +#[macro_export] +macro_rules! unwrap_or_return_result { + ($e:expr) => { + match $e { + Ok(val) => val, + Err(_) => return Ok(()), + } + }; +} + +#[macro_export] +macro_rules! unwrap_or_return_option { + ($e:expr) => { + match $e { + Some(val) => val, + None => return Ok(()), + } + }; +} + fn forward( - message: &str, + message: String, stream: &Sender, server_local: &str, server_forward: &str, ) -> Result<(), mpsc::SendError> { - if ircparser::parse(message).unwrap()[0].command == "PRIVMSG" { - return Ok(()); + match ircparser::parse(&message) { + Ok(val) => match val[0].command.as_str() { + "PRIVMSG" => stream.send(message), + _ => stream.send(message.replace(&server_local, server_forward)), + }, + Err(_) => stream.send(message.replace(server_local, server_forward)), } - stream.send(message.replace(&server_local, server_forward)) } pub fn bytes_to_privmsg_base64(message: &Vec, reciever: &str) -> String { @@ -45,11 +79,11 @@ pub fn send_key(sender: &Sender, reciever: &str, key: &Vec) -> Resul Ok(()) } -pub fn get_nick(userstring: &str) -> String { +pub fn get_nick(userstring: &str) -> Option { let userstring = userstring.chars().collect::>(); - let start_pos = userstring.iter().position(|&x| x == ':').unwrap() + 1; - let end_pos = userstring.iter().position(|&x| x == '!').unwrap(); - userstring[start_pos..end_pos].iter().collect::() + let start_pos = userstring.iter().position(|&x| x == ':')? + 1; + let end_pos = userstring.iter().position(|&x| x == '!')?; + Some(userstring[start_pos..end_pos].iter().collect::()) } pub fn recieve_message_base64( @@ -65,7 +99,13 @@ pub fn recieve_message_base64( while !message.contains(&end.to_string()) { let recieved_raw = writer_channel_rx.recv()?; - let recieved = &ircparser::parse(&recieved_raw).unwrap()[0]; + let parse_result = ircparser::parse(&recieved_raw); + + let recieved = match parse_result { + Ok(mut val) => val.pop_back().unwrap(), + Err(_) => return Err(IrcParseError.into()), + }; + let begin_source_reciever = format!(":{sender}!"); if recieved.command != "PRIVMSG" || !recieved @@ -75,7 +115,7 @@ pub fn recieve_message_base64( .starts_with(&begin_source_reciever) || recieved.params[0].starts_with("#") { - forward(&recieved_raw, forward_stream, server_local, server_forward)?; + forward(recieved_raw, forward_stream, server_local, server_forward)?; continue; } diff --git a/src/listener_server.rs b/src/listener_server.rs index 4f4f616..40ceae1 100644 --- a/src/listener_server.rs +++ b/src/listener_server.rs @@ -1,17 +1,10 @@ use std::io::{ErrorKind, Read, Write}; -use std::net::TcpListener; -use std::sync::mpsc; +use std::net::{TcpListener, TcpStream}; +use std::sync::mpsc::{self, TryRecvError}; use std::thread; use std::time::Duration; -pub fn listen_to_client(tx: mpsc::Sender, rx: mpsc::Receiver, port: &str) { - let listener = TcpListener::bind("127.0.0.1:".to_string() + port).unwrap(); - let mut stream = listener.accept().unwrap().0; - - stream - .set_nonblocking(true) - .expect("Couldn't set nonblocking"); - +fn stream_handler(tx: &mpsc::Sender, rx: &mpsc::Receiver, mut stream: TcpStream) { loop { let mut buffer: Vec = Vec::new(); let mut buf: [u8; 1] = [0]; @@ -24,6 +17,7 @@ pub fn listen_to_client(tx: mpsc::Sender, rx: mpsc::Receiver, po ErrorKind::WouldBlock => {} _ => { dbg!(_error); + return; } }, } @@ -31,10 +25,14 @@ pub fn listen_to_client(tx: mpsc::Sender, rx: mpsc::Receiver, po Ok(value) => { match stream.write_all(value.as_bytes()) { Ok(_) => {} - Err(_e) => println!("Couldn't send {value}"), + Err(_e) => { + dbg!(_e); + return; + } }; } - Err(_e) => {} + Err(TryRecvError::Empty) => {}, + Err(TryRecvError::Disconnected) => return, } thread::sleep(Duration::from_micros(100)); } @@ -42,3 +40,19 @@ pub fn listen_to_client(tx: mpsc::Sender, rx: mpsc::Receiver, po let _ = tx.send(String::from_utf8_lossy(&buffer).to_string()); } } + +pub fn listen_to_client(tx: mpsc::Sender, rx: mpsc::Receiver, port: String) { + let listener = TcpListener::bind("127.0.0.1:".to_string() + &port) + .expect(&("Couldn't start listener on 127.0.0.1 port ".to_string() + &port)); + loop { + let (stream, ip) = listener.accept().unwrap(); + println!("Got connection from {ip}"); + + stream + .set_nonblocking(true) + .expect("Couldn't set nonblocking"); + + stream_handler(&tx, &rx, stream); + println!("Closed connection with {ip}"); + } +} diff --git a/src/main.rs b/src/main.rs index 506de9c..4780cc4 100644 --- a/src/main.rs +++ b/src/main.rs @@ -1,11 +1,12 @@ +use dotenv::{dotenv, vars}; use eyre::Result; use pgp::{Deserializable, SignedPublicKey, SignedSecretKey}; use std::collections::HashMap; +use std::fs; use std::net::{Shutdown, TcpStream}; use std::sync::mpsc; use std::thread; use std::time::Duration; -use std::{env, fs}; mod client_handler; mod encryption; @@ -15,13 +16,24 @@ mod server_handler; mod writer_client; fn main() -> Result<()> { - let mut args = env::args().collect::>(); + dotenv().expect("Couldn't load .env. It probably doesn't exist"); + let mut vars_hashmap = HashMap::new(); - let server = args[1].clone(); - let port = args[2].clone(); + for var in vars() { + vars_hashmap.insert(var.0, var.1); + } - let default_password = String::new(); - let passwd = args.pop().unwrap_or(default_password); + let server = &vars_hashmap["SERVER"]; + + let default_passwd = String::new(); + + let port = match vars_hashmap.get("PORT") { + Some(val) => val, + None => "6666", + } + .to_string(); + + let passwd = vars_hashmap.get("PASSWD").unwrap_or(&default_passwd); let stream = TcpStream::connect(format!("{server}:6697"))?; @@ -52,7 +64,7 @@ fn main() -> Result<()> { let (writer_channel_recv_tx, writer_channel_rx) = mpsc::channel(); thread::spawn(move || { - listener_server::listen_to_client(listener_channel_send_tx, listener_channel_recv_rx, &port) + listener_server::listen_to_client(listener_channel_send_tx, listener_channel_recv_rx, port) }); let tmp_server = server.clone(); thread::spawn(|| { @@ -68,16 +80,18 @@ fn main() -> Result<()> { loop { match listener_channel_rx.try_recv() { - Ok(message) => client_handler::handle_message_from_client( - &message, - &public_key, - &server, - &mut keys, - &writer_channel_tx, - &writer_channel_rx, - &listener_channel_tx, - &listener_channel_rx, - )?, + Ok(message) => { + let _ = client_handler::handle_message_from_client( + &message, + &public_key, + server, + &mut keys, + &writer_channel_tx, + &writer_channel_rx, + &listener_channel_tx, + &listener_channel_rx, + ); + } Err(error) => match error { mpsc::TryRecvError::Empty => {} mpsc::TryRecvError::Disconnected => panic!("listener_channel_rx disconnected"), @@ -85,18 +99,20 @@ fn main() -> Result<()> { }; match writer_channel_rx.try_recv() { - Ok(message) => server_handler::handle_message_from_server( - &message, - &public_key, - &secret_key, - &server, - &passwd, - &mut keys, - &writer_channel_tx, - &writer_channel_rx, - &listener_channel_tx, - &listener_channel_rx, - )?, + Ok(message) => { + let _ = server_handler::handle_message_from_server( + &message, + &public_key, + &secret_key, + server, + passwd, + &mut keys, + &writer_channel_tx, + &writer_channel_rx, + &listener_channel_tx, + &listener_channel_rx, + ); + } Err(error) => match error { mpsc::TryRecvError::Empty => {} mpsc::TryRecvError::Disconnected => panic!("writer_channel_rx disconnected"), diff --git a/src/server_handler.rs b/src/server_handler.rs index 92aea38..c5e545e 100644 --- a/src/server_handler.rs +++ b/src/server_handler.rs @@ -1,3 +1,5 @@ +use crate::unwrap_or_return_option; +use crate::unwrap_or_return_result; use crate::{encryption, helpers}; use eyre::Result; use pgp::{Deserializable, SignedPublicKey, SignedSecretKey}; @@ -9,7 +11,7 @@ fn forward( listener_channel_tx: &Sender, server: &str, ) -> Result<(), mpsc::SendError> { - listener_channel_tx.send(message.replace(&server, "127.0.0.1")) + listener_channel_tx.send(message.replace(server, "127.0.0.1")) } pub fn handle_message_from_server( @@ -24,14 +26,14 @@ pub fn handle_message_from_server( listener_channel_tx: &Sender, _listener_channel_rx: &Receiver, ) -> Result<()> { - let recieved_parsed = &ircparser::parse(&recieved).unwrap()[0]; + let recieved_parsed = &unwrap_or_return_result!(ircparser::parse(recieved))[0]; if recieved_parsed.command != "PRIVMSG" || recieved_parsed .params .get(0) .unwrap_or(&String::new()) - .starts_with("#") + .starts_with('#') { forward(recieved, listener_channel_tx, server)?; return Ok(()); @@ -39,30 +41,32 @@ pub fn handle_message_from_server( dbg!(&recieved_parsed); + let source = unwrap_or_return_option!(&recieved_parsed.source); + if recieved_parsed.params[1] == "START_MESSAGE" { - let sender = helpers::get_nick(&recieved_parsed.source.clone().unwrap()); + let sender = unwrap_or_return_option!(helpers::get_nick(source)); let message = helpers::recieve_message_base64( writer_channel_rx, listener_channel_tx, - &server, + server, "127.0.0.1", &sender, "END_MESSAGE", )?; // Get - let message = encryption::decrypt(&secret_key, &message, &passwd)?; // Decrypt + let message = encryption::decrypt(secret_key, &message, passwd)?; // Decrypt listener_channel_tx.send(recieved.replace("START_MESSAGE", &message))?; // Send } else if recieved_parsed.params[1] == "START_KEY" { - let sender = helpers::get_nick(&recieved_parsed.source.clone().unwrap()); - let to_send = helpers::bytes_to_privmsg_base64(&public_key, &sender); + let sender = unwrap_or_return_option!(helpers::get_nick(source)); + let to_send = helpers::bytes_to_privmsg_base64(public_key, &sender); writer_channel_tx.send(to_send)?; writer_channel_tx.send(format!("PRIVMSG {sender} END_KEY\r\n"))?; let foreign_key = helpers::recieve_message_base64( writer_channel_rx, listener_channel_tx, - &server, + server, "127.0.0.1", &sender, "END_KEY", diff --git a/src/writer_client.rs b/src/writer_client.rs index 4b413d3..c69e017 100644 --- a/src/writer_client.rs +++ b/src/writer_client.rs @@ -35,7 +35,7 @@ pub fn write_to_server( } Err(_error) => match _error.io_error() { None => { - dbg!(_error); + dbg!(_error.ssl_error()); } Some(error) => match error.kind() { ErrorKind::WouldBlock => {}