diff --git a/main.rs b/main.rs deleted file mode 100644 index e69de29..0000000 diff --git a/public/index.html b/public/index.html index a7f1b58..64bab5b 100644 --- a/public/index.html +++ b/public/index.html @@ -4,17 +4,7 @@ Hello World! - + @@ -27,6 +17,14 @@
+
+ +
+ +
+ + +
diff --git a/public/index.js b/public/index.js new file mode 100644 index 0000000..c6dcef4 --- /dev/null +++ b/public/index.js @@ -0,0 +1,12 @@ +const socket = new WebSocket("ws://localhost:8080/websocket"); + +socket.addEventListener("message", (event) => { + let messages = document.getElementById("messages"); + messages.appendChild(`
  • ${event.data}
  • `) +}); + +function myFunction() { + socket.send(document.getElementById("inp").value); + + return false; +} diff --git a/src/main.rs b/src/main.rs index 003a02b..9b07f6a 100644 --- a/src/main.rs +++ b/src/main.rs @@ -6,36 +6,44 @@ mod websoket_connection; use std::time::Duration; use std::{path::Path, str::FromStr}; -use tokio::io::AsyncWriteExt; +use tokio::io::{self, AsyncWriteExt}; use tokio::net::{TcpListener, TcpStream}; use tokio::time; -use crate::websoket_connection::{FrameType, WebsocketConnection}; +use crate::websoket_connection::{FrameType, WebsocketRead, WebsocketWrite}; use crate::{ request::{Connection, ServerPath}, response::{Response, ResponseCode, ResponseHeader}, }; +use tokio::sync; + #[tokio::main] async fn main() -> tokio::io::Result<()> { let listener = TcpListener::bind("127.0.0.1:8080").await?; + let (sender, _) = sync::broadcast::channel(16); loop { let (stream, _) = listener.accept().await?; - - tokio::spawn(handle_connection(stream)); + let receiver = sender.subscribe(); + let sender = sender.clone(); + tokio::spawn(handle_connection(stream, receiver, sender)); } } -async fn handle_connection(stream: TcpStream) -> tokio::io::Result<()> { +async fn handle_connection( + stream: TcpStream, + receiver: sync::broadcast::Receiver, + sender: sync::broadcast::Sender, +) -> tokio::io::Result<()> { if let Some(ws) = handle_http_connection(stream).await? { - handle_websocket(ws).await? + handle_websocket(ws, receiver, sender).await? } Ok(()) } async fn handle_http_connection( mut stream: TcpStream, -) -> tokio::io::Result> { +) -> tokio::io::Result> { let mut timeout = 500; loop { let req = match time::timeout( @@ -63,7 +71,7 @@ async fn handle_http_connection( Err(_) => Response::new().with_code(ResponseCode::NotFound), } } - ["websocket"] => match WebsocketConnection::initialize_connection(req, stream).await { + ["websocket"] => match websoket_connection::initialize_connection(req, stream).await { Ok(ws) => { return Ok(Some(ws)); } @@ -91,15 +99,45 @@ async fn handle_http_connection( Ok(None) } -async fn handle_websocket(mut web_socket: WebsocketConnection) -> tokio::io::Result<()> { +async fn handle_websocket( + mut web_socket: (WebsocketRead, WebsocketWrite), + receiver: sync::broadcast::Receiver, + sender: sync::broadcast::Sender, +) -> tokio::io::Result<()> { + tokio::spawn(broadcast_message(web_socket.1, receiver)); + loop { - let message = web_socket.read_next_message().await?; + let message = web_socket.0.read_next_message().await?; if message.frame_type == FrameType::TextFrame { - println!("{}", String::from_utf8_lossy(&message.data)); - web_socket - .send_message(FrameType::TextFrame, "message_received".as_bytes()) - .await?; + let s = String::from_utf8_lossy(&message.data).to_string(); + println!("{}", s); + let _ = sender.send(s); + } + } +} + +enum BroadcastError { + IoError(io::Error), + BroadcastError(sync::broadcast::error::RecvError), +} + +async fn broadcast_message( + mut write: WebsocketWrite, + mut receiver: sync::broadcast::Receiver, +) -> Result<(), BroadcastError> { + loop { + let new_message = match receiver.recv().await { + Ok(s) => s, + Err(e) => return Err(BroadcastError::BroadcastError(e)), + }; + + match write + .send_message(FrameType::TextFrame, new_message.as_bytes()) + .await + { + Ok(()) => {} + Err(e) => return Err(BroadcastError::IoError(e)), } } } diff --git a/src/response.rs b/src/response.rs index 7cb2fc6..9353e70 100644 --- a/src/response.rs +++ b/src/response.rs @@ -25,7 +25,7 @@ impl Response { output.extend_from_slice(format!("Content-Length: {}", self.data.len()).as_bytes()); output.extend_from_slice(b"\r\n\r\n"); output.extend_from_slice(&self.data); - } else { + } else if !self.headers.is_empty() { output.extend_from_slice(b"\r\n"); } diff --git a/src/websoket_connection.rs b/src/websoket_connection.rs index de2be4e..85eaad3 100644 --- a/src/websoket_connection.rs +++ b/src/websoket_connection.rs @@ -3,14 +3,21 @@ use crate::{ response::Response, }; -use tokio::io::{self, AsyncReadExt, AsyncWriteExt}; use tokio::net::TcpStream; +use tokio::{ + io::{self, AsyncReadExt, AsyncWriteExt}, + net::tcp::{OwnedReadHalf, OwnedWriteHalf}, +}; use base64::prelude::*; use sha1::{Digest, Sha1}; -pub struct WebsocketConnection { - stream: TcpStream, +pub struct WebsocketRead { + read: OwnedReadHalf, +} + +pub struct WebsocketWrite { + write: OwnedWriteHalf, } struct DataBlock { @@ -41,11 +48,10 @@ pub enum FrameType { OtherNonControl(u8), } -impl WebsocketConnection { +impl WebsocketWrite { pub async fn send_message(&mut self, frame_type: FrameType, data: &[u8]) -> io::Result<()> { - let mut header = Vec::with_capacity(14); // Max header size for 127-length payload + let mut header = Vec::with_capacity(14); - // First byte: FIN (1) + RSV1-3 (000) + opcode let opcode = match frame_type { FrameType::TextFrame => 0x1, FrameType::BinaryFrame => 0x2, @@ -56,7 +62,6 @@ impl WebsocketConnection { }; header.push(0b1000_0000 | opcode); // FIN = 1 - // Second byte: MASK bit = 0 (server -> client frames are NOT masked) let payload_len = data.len(); if payload_len < 126 { header.push(payload_len as u8); @@ -68,14 +73,15 @@ impl WebsocketConnection { header.extend_from_slice(&(payload_len as u64).to_be_bytes()); } - // Send header + payload - self.stream.write_all(&header).await?; - self.stream.write_all(data).await?; - self.stream.flush().await?; + self.write.write_all(&header).await?; + self.write.write_all(data).await?; + self.write.flush().await?; Ok(()) } +} +impl WebsocketRead { pub async fn read_next_message(&mut self) -> io::Result { let first_line = self.parse_single_block().await?; @@ -114,7 +120,7 @@ impl WebsocketConnection { async fn parse_single_block(&mut self) -> io::Result { let mut first_line: [u8; 2] = [0; 2]; - self.stream.read_exact(&mut first_line).await?; + self.read.read_exact(&mut first_line).await?; let get_bool = |index: u8, byte: u8| -> bool { byte & (1 << index) != 0 }; @@ -140,19 +146,15 @@ impl WebsocketConnection { let mask = get_bool(7, first_line[1]); let length = match first_line[1] & 0b01111111 { - 126 => self.stream.read_u16().await? as u64, - 127 => self.stream.read_u64().await?, + 126 => self.read.read_u16().await? as u64, + 127 => self.read.read_u64().await?, other => other as u64, }; - let masking_key = if mask { - self.stream.read_u32().await? - } else { - 0 - }; + let masking_key = if mask { self.read.read_u32().await? } else { 0 }; let mut message_data = vec![0u8; length as usize]; - self.stream.read_exact(&mut message_data).await?; + self.read.read_exact(&mut message_data).await?; if mask { Self::unmask_block(&mut message_data, masking_key); @@ -167,74 +169,78 @@ impl WebsocketConnection { data: message_data, }) } +} +pub async fn initialize_connection( + req: Request, + mut stream: TcpStream, +) -> tokio::io::Result<(WebsocketRead, WebsocketWrite)> { + let (mut upgrade, mut connection, mut key_exists) = (false, false, false); + let mut key_val: Box = "".into(); - pub async fn initialize_connection( - req: Request, - mut stream: TcpStream, - ) -> tokio::io::Result { - let (mut upgrade, mut connection, mut key_exists) = (false, false, false); - let mut key_val: Box = "".into(); - - for i in req.headers { - match i { - RequestHeader::Upgrade(upgrad) => { - if let Some(upg) = upgrad.first() - && upg.protocol == Protocol::Websocket - { - upgrade = true; - } + for i in req.headers { + match i { + RequestHeader::Upgrade(upgrad) => { + if let Some(upg) = upgrad.first() + && upg.protocol == Protocol::Websocket + { + upgrade = true; } - RequestHeader::Connection(con) => { - if con == Connection::Upgrade { - connection = true; - } else if let Connection::Other(c) = con - && c.contains("Upgrade") - { - connection = true; - } - } - RequestHeader::Other { name, value } => { - if name == "Sec-WebSocket-Key".into() { - key_val = value.clone(); - key_exists = true; - } - } - _ => (), } - } - - if upgrade && connection && key_exists { - let magic_val = b"258EAFA5-E914-47DA-95CA-C5AB0DC85B11"; - let mut hasher = Sha1::new(); - hasher.update(key_val.as_bytes()); - hasher.update(magic_val); - - let result = hasher.finalize(); - let result = BASE64_STANDARD.encode(result); - - let rep = Response::new() - .with_code(crate::response::ResponseCode::SwitchingProtocols) - .with_header(crate::response::ResponseHeader::Upgrade(Upgrade { - protocol: Protocol::Websocket, - version: None, - })) - .with_header(crate::response::ResponseHeader::Connection( - Connection::Upgrade, - )) - .with_header(crate::response::ResponseHeader::Other { - header_name: "Sec-WebSocket-Accept".into(), - header_value: result.into(), - }); - rep.respond(&mut stream).await?; - - Ok(Self { stream }) - } else { - Response::new() - .with_code(crate::response::ResponseCode::BadRequest) - .respond(&mut stream) - .await?; - stream.flush().await?; - Err(io::Error::new(io::ErrorKind::InvalidData, "Wrong request")) + RequestHeader::Connection(con) => { + if con == Connection::Upgrade { + connection = true; + } else if let Connection::Other(c) = con + && c.contains("Upgrade") + { + connection = true; + } + } + RequestHeader::Other { name, value } => { + if name == "Sec-WebSocket-Key".into() { + key_val = value.clone(); + key_exists = true; + } + } + _ => (), } } + + if upgrade && connection && key_exists { + let magic_val = b"258EAFA5-E914-47DA-95CA-C5AB0DC85B11"; + let mut hasher = Sha1::new(); + hasher.update(key_val.as_bytes()); + hasher.update(magic_val); + + let result = hasher.finalize(); + let result = BASE64_STANDARD.encode(result); + + let rep = Response::new() + .with_code(crate::response::ResponseCode::SwitchingProtocols) + .with_header(crate::response::ResponseHeader::Upgrade(Upgrade { + protocol: Protocol::Websocket, + version: None, + })) + .with_header(crate::response::ResponseHeader::Connection( + Connection::Upgrade, + )) + .with_header(crate::response::ResponseHeader::Other { + header_name: "Sec-WebSocket-Accept".into(), + header_value: result.into(), + }); + rep.respond(&mut stream).await?; + + let (read_halve, write_halve) = stream.into_split(); + + Ok(( + WebsocketRead { read: read_halve }, + WebsocketWrite { write: write_halve }, + )) + } else { + Response::new() + .with_code(crate::response::ResponseCode::BadRequest) + .respond(&mut stream) + .await?; + stream.flush().await?; + Err(io::Error::new(io::ErrorKind::InvalidData, "Wrong request")) + } }