diff --git a/main.rs b/main.rs new file mode 100644 index 0000000..e69de29 diff --git a/public/index.css b/public/index.css index 211c114..8b5c750 100644 --- a/public/index.css +++ b/public/index.css @@ -10,10 +10,6 @@ html { height: 100%; } -li { - color: #fff; -} - body { display: flex; flex-direction: row; diff --git a/public/index.html b/public/index.html index 819b5a0..a7f1b58 100644 --- a/public/index.html +++ b/public/index.html @@ -4,7 +4,17 @@ Hello World! - + @@ -17,14 +27,6 @@
-
- -
- -
- - -
diff --git a/public/index.js b/public/index.js deleted file mode 100644 index 58f68f1..0000000 --- a/public/index.js +++ /dev/null @@ -1,14 +0,0 @@ -const socket = new WebSocket("ws://localhost:8080/websocket"); - -socket.addEventListener("message", (event) => { - let messages = document.getElementById("messages"); - let item = document.createElement("li"); - item.innerHTML = event.data; - messages.appendChild(item); -}); - -function myFunction() { - socket.send(document.getElementById("inp").value); - - return false; -} diff --git a/src/main.rs b/src/main.rs index cdbc08a..003a02b 100644 --- a/src/main.rs +++ b/src/main.rs @@ -6,44 +6,36 @@ mod websoket_connection; use std::time::Duration; use std::{path::Path, str::FromStr}; -use tokio::io::{self, AsyncWriteExt}; +use tokio::io::AsyncWriteExt; use tokio::net::{TcpListener, TcpStream}; use tokio::time; -use crate::websoket_connection::{FrameType, WebsocketRead, WebsocketWrite}; +use crate::websoket_connection::{FrameType, WebsocketConnection}; use crate::{ request::{Connection, ServerPath}, response::{Response, ResponseCode, ResponseHeader}, }; -use tokio::sync; - #[tokio::main] async fn main() -> tokio::io::Result<()> { let listener = TcpListener::bind("127.0.0.1:8080").await?; - let (sender, _) = sync::broadcast::channel(16); loop { let (stream, _) = listener.accept().await?; - let receiver = sender.subscribe(); - let sender = sender.clone(); - tokio::spawn(handle_connection(stream, receiver, sender)); + + tokio::spawn(handle_connection(stream)); } } -async fn handle_connection( - stream: TcpStream, - receiver: sync::broadcast::Receiver, - sender: sync::broadcast::Sender, -) -> tokio::io::Result<()> { +async fn handle_connection(stream: TcpStream) -> tokio::io::Result<()> { if let Some(ws) = handle_http_connection(stream).await? { - handle_websocket(ws, receiver, sender).await? + handle_websocket(ws).await? } Ok(()) } async fn handle_http_connection( mut stream: TcpStream, -) -> tokio::io::Result> { +) -> tokio::io::Result> { let mut timeout = 500; loop { let req = match time::timeout( @@ -62,8 +54,6 @@ async fn handle_http_connection( } }; - println!("{req:?}"); - let matchable = req.path.to_matchable(); let response = match matchable.as_slice() { @@ -73,7 +63,7 @@ async fn handle_http_connection( Err(_) => Response::new().with_code(ResponseCode::NotFound), } } - ["websocket"] => match websoket_connection::initialize_connection(req, stream).await { + ["websocket"] => match WebsocketConnection::initialize_connection(req, stream).await { Ok(ws) => { return Ok(Some(ws)); } @@ -101,45 +91,15 @@ async fn handle_http_connection( Ok(None) } -async fn handle_websocket( - mut web_socket: (WebsocketRead, WebsocketWrite), - receiver: sync::broadcast::Receiver, - sender: sync::broadcast::Sender, -) -> tokio::io::Result<()> { - tokio::spawn(broadcast_message(web_socket.1, receiver)); - +async fn handle_websocket(mut web_socket: WebsocketConnection) -> tokio::io::Result<()> { loop { - let message = web_socket.0.read_next_message().await?; + let message = web_socket.read_next_message().await?; if message.frame_type == FrameType::TextFrame { - let s = String::from_utf8_lossy(&message.data).to_string(); - println!("{}", s); - let _ = sender.send(s); - } - } -} - -enum BroadcastError { - IoError(io::Error), - BroadcastError(sync::broadcast::error::RecvError), -} - -async fn broadcast_message( - mut write: WebsocketWrite, - mut receiver: sync::broadcast::Receiver, -) -> Result<(), BroadcastError> { - loop { - let new_message = match receiver.recv().await { - Ok(s) => s, - Err(e) => return Err(BroadcastError::BroadcastError(e)), - }; - - match write - .send_message(FrameType::TextFrame, new_message.as_bytes()) - .await - { - Ok(()) => {} - Err(e) => return Err(BroadcastError::IoError(e)), + println!("{}", String::from_utf8_lossy(&message.data)); + web_socket + .send_message(FrameType::TextFrame, "message_received".as_bytes()) + .await?; } } } diff --git a/src/response.rs b/src/response.rs index e301211..7cb2fc6 100644 --- a/src/response.rs +++ b/src/response.rs @@ -25,7 +25,7 @@ impl Response { output.extend_from_slice(format!("Content-Length: {}", self.data.len()).as_bytes()); output.extend_from_slice(b"\r\n\r\n"); output.extend_from_slice(&self.data); - } else if !self.headers.is_empty() { + } else { output.extend_from_slice(b"\r\n"); } @@ -97,9 +97,6 @@ impl Response { Some(a) if a == OsStr::new("css") => { ContentType::Text(crate::shared_enums::TextType::Css) } - Some(a) if a == OsStr::new("js") => { - ContentType::Text(crate::shared_enums::TextType::Javascript) - } Some(_) | None => { return Err(io::Error::new( io::ErrorKind::InvalidInput, diff --git a/src/websoket_connection.rs b/src/websoket_connection.rs index 85eaad3..de2be4e 100644 --- a/src/websoket_connection.rs +++ b/src/websoket_connection.rs @@ -3,21 +3,14 @@ use crate::{ response::Response, }; +use tokio::io::{self, AsyncReadExt, AsyncWriteExt}; use tokio::net::TcpStream; -use tokio::{ - io::{self, AsyncReadExt, AsyncWriteExt}, - net::tcp::{OwnedReadHalf, OwnedWriteHalf}, -}; use base64::prelude::*; use sha1::{Digest, Sha1}; -pub struct WebsocketRead { - read: OwnedReadHalf, -} - -pub struct WebsocketWrite { - write: OwnedWriteHalf, +pub struct WebsocketConnection { + stream: TcpStream, } struct DataBlock { @@ -48,10 +41,11 @@ pub enum FrameType { OtherNonControl(u8), } -impl WebsocketWrite { +impl WebsocketConnection { pub async fn send_message(&mut self, frame_type: FrameType, data: &[u8]) -> io::Result<()> { - let mut header = Vec::with_capacity(14); + let mut header = Vec::with_capacity(14); // Max header size for 127-length payload + // First byte: FIN (1) + RSV1-3 (000) + opcode let opcode = match frame_type { FrameType::TextFrame => 0x1, FrameType::BinaryFrame => 0x2, @@ -62,6 +56,7 @@ impl WebsocketWrite { }; header.push(0b1000_0000 | opcode); // FIN = 1 + // Second byte: MASK bit = 0 (server -> client frames are NOT masked) let payload_len = data.len(); if payload_len < 126 { header.push(payload_len as u8); @@ -73,15 +68,14 @@ impl WebsocketWrite { header.extend_from_slice(&(payload_len as u64).to_be_bytes()); } - self.write.write_all(&header).await?; - self.write.write_all(data).await?; - self.write.flush().await?; + // Send header + payload + self.stream.write_all(&header).await?; + self.stream.write_all(data).await?; + self.stream.flush().await?; Ok(()) } -} -impl WebsocketRead { pub async fn read_next_message(&mut self) -> io::Result { let first_line = self.parse_single_block().await?; @@ -120,7 +114,7 @@ impl WebsocketRead { async fn parse_single_block(&mut self) -> io::Result { let mut first_line: [u8; 2] = [0; 2]; - self.read.read_exact(&mut first_line).await?; + self.stream.read_exact(&mut first_line).await?; let get_bool = |index: u8, byte: u8| -> bool { byte & (1 << index) != 0 }; @@ -146,15 +140,19 @@ impl WebsocketRead { let mask = get_bool(7, first_line[1]); let length = match first_line[1] & 0b01111111 { - 126 => self.read.read_u16().await? as u64, - 127 => self.read.read_u64().await?, + 126 => self.stream.read_u16().await? as u64, + 127 => self.stream.read_u64().await?, other => other as u64, }; - let masking_key = if mask { self.read.read_u32().await? } else { 0 }; + let masking_key = if mask { + self.stream.read_u32().await? + } else { + 0 + }; let mut message_data = vec![0u8; length as usize]; - self.read.read_exact(&mut message_data).await?; + self.stream.read_exact(&mut message_data).await?; if mask { Self::unmask_block(&mut message_data, masking_key); @@ -169,78 +167,74 @@ impl WebsocketRead { data: message_data, }) } -} -pub async fn initialize_connection( - req: Request, - mut stream: TcpStream, -) -> tokio::io::Result<(WebsocketRead, WebsocketWrite)> { - let (mut upgrade, mut connection, mut key_exists) = (false, false, false); - let mut key_val: Box = "".into(); - for i in req.headers { - match i { - RequestHeader::Upgrade(upgrad) => { - if let Some(upg) = upgrad.first() - && upg.protocol == Protocol::Websocket - { - upgrade = true; + pub async fn initialize_connection( + req: Request, + mut stream: TcpStream, + ) -> tokio::io::Result { + let (mut upgrade, mut connection, mut key_exists) = (false, false, false); + let mut key_val: Box = "".into(); + + for i in req.headers { + match i { + RequestHeader::Upgrade(upgrad) => { + if let Some(upg) = upgrad.first() + && upg.protocol == Protocol::Websocket + { + upgrade = true; + } } - } - RequestHeader::Connection(con) => { - if con == Connection::Upgrade { - connection = true; - } else if let Connection::Other(c) = con - && c.contains("Upgrade") - { - connection = true; + RequestHeader::Connection(con) => { + if con == Connection::Upgrade { + connection = true; + } else if let Connection::Other(c) = con + && c.contains("Upgrade") + { + connection = true; + } } - } - RequestHeader::Other { name, value } => { - if name == "Sec-WebSocket-Key".into() { - key_val = value.clone(); - key_exists = true; + RequestHeader::Other { name, value } => { + if name == "Sec-WebSocket-Key".into() { + key_val = value.clone(); + key_exists = true; + } } + _ => (), } - _ => (), + } + + if upgrade && connection && key_exists { + let magic_val = b"258EAFA5-E914-47DA-95CA-C5AB0DC85B11"; + let mut hasher = Sha1::new(); + hasher.update(key_val.as_bytes()); + hasher.update(magic_val); + + let result = hasher.finalize(); + let result = BASE64_STANDARD.encode(result); + + let rep = Response::new() + .with_code(crate::response::ResponseCode::SwitchingProtocols) + .with_header(crate::response::ResponseHeader::Upgrade(Upgrade { + protocol: Protocol::Websocket, + version: None, + })) + .with_header(crate::response::ResponseHeader::Connection( + Connection::Upgrade, + )) + .with_header(crate::response::ResponseHeader::Other { + header_name: "Sec-WebSocket-Accept".into(), + header_value: result.into(), + }); + rep.respond(&mut stream).await?; + + Ok(Self { stream }) + } else { + Response::new() + .with_code(crate::response::ResponseCode::BadRequest) + .respond(&mut stream) + .await?; + stream.flush().await?; + Err(io::Error::new(io::ErrorKind::InvalidData, "Wrong request")) } } - - if upgrade && connection && key_exists { - let magic_val = b"258EAFA5-E914-47DA-95CA-C5AB0DC85B11"; - let mut hasher = Sha1::new(); - hasher.update(key_val.as_bytes()); - hasher.update(magic_val); - - let result = hasher.finalize(); - let result = BASE64_STANDARD.encode(result); - - let rep = Response::new() - .with_code(crate::response::ResponseCode::SwitchingProtocols) - .with_header(crate::response::ResponseHeader::Upgrade(Upgrade { - protocol: Protocol::Websocket, - version: None, - })) - .with_header(crate::response::ResponseHeader::Connection( - Connection::Upgrade, - )) - .with_header(crate::response::ResponseHeader::Other { - header_name: "Sec-WebSocket-Accept".into(), - header_value: result.into(), - }); - rep.respond(&mut stream).await?; - - let (read_halve, write_halve) = stream.into_split(); - - Ok(( - WebsocketRead { read: read_halve }, - WebsocketWrite { write: write_halve }, - )) - } else { - Response::new() - .with_code(crate::response::ResponseCode::BadRequest) - .respond(&mut stream) - .await?; - stream.flush().await?; - Err(io::Error::new(io::ErrorKind::InvalidData, "Wrong request")) - } }