use crate::{ request::{Connection, Protocol, Request, RequestHeader, Upgrade}, response::Response, }; use tokio::io::{self, AsyncReadExt, AsyncWriteExt}; use tokio::net::TcpStream; use base64::prelude::*; use sha1::{Digest, Sha1}; pub struct WebsocketConnection { stream: TcpStream, } struct DataBlock { is_final: bool, e1: bool, e2: bool, e3: bool, message_type: FrameType, data: Vec, } pub struct DataFrame { pub frame_type: FrameType, pub data: Box<[u8]>, } #[derive(PartialEq, Eq)] pub enum FrameType { Continuation, TextFrame, BinaryFrame, ConnectionClose, Ping, Pong, OtherControl(u8), OtherNonControl(u8), } impl WebsocketConnection { pub async fn send_message(&mut self, frame_type: FrameType, data: &[u8]) -> io::Result<()> { let mut header = Vec::with_capacity(14); // Max header size for 127-length payload // First byte: FIN (1) + RSV1-3 (000) + opcode let opcode = match frame_type { FrameType::TextFrame => 0x1, FrameType::BinaryFrame => 0x2, FrameType::Ping => 0x9, FrameType::Pong => 0xA, FrameType::ConnectionClose => 0x8, _ => panic!("No other type should by passed to this function"), }; header.push(0b1000_0000 | opcode); // FIN = 1 // Second byte: MASK bit = 0 (server -> client frames are NOT masked) let payload_len = data.len(); if payload_len < 126 { header.push(payload_len as u8); } else if payload_len <= u16::MAX as usize { header.push(126); header.extend_from_slice(&(payload_len as u16).to_be_bytes()); } else { header.push(127); header.extend_from_slice(&(payload_len as u64).to_be_bytes()); } // Send header + payload self.stream.write_all(&header).await?; self.stream.write_all(data).await?; self.stream.flush().await?; Ok(()) } pub async fn read_next_message(&mut self) -> io::Result { let first_line = self.parse_single_block().await?; let mut data = first_line.data; let frame_type = first_line.message_type; if !first_line.is_final { let mut current_line = self.parse_single_block().await?; while !current_line.is_final { if current_line.message_type != FrameType::Continuation { return Err(io::Error::new( io::ErrorKind::InvalidInput, "That is not how websocket works!!!", )); } data.extend_from_slice(¤t_line.data); current_line = self.parse_single_block().await?; } data.extend_from_slice(¤t_line.data); } Ok(DataFrame { frame_type, data: data.into_boxed_slice(), }) } fn unmask_block(data: &mut [u8], mask: u32) { let mask_bytes = mask.to_be_bytes(); for (i, e) in data.iter_mut().enumerate() { *e ^= mask_bytes[i % 4]; } } async fn parse_single_block(&mut self) -> io::Result { let mut first_line: [u8; 2] = [0; 2]; self.stream.read_exact(&mut first_line).await?; let get_bool = |index: u8, byte: u8| -> bool { byte & (1 << index) != 0 }; let is_final = get_bool(7, first_line[0]); let extension_bit_1 = get_bool(6, first_line[0]); let extension_bit_2 = get_bool(5, first_line[0]); let extension_bit_3 = get_bool(4, first_line[0]); let message_type = match first_line[0] & 0b00001111 { 0x0 => FrameType::Continuation, 0x1 => FrameType::TextFrame, 0x2 => FrameType::BinaryFrame, 0x8 => FrameType::ConnectionClose, 0x9 => FrameType::Ping, 0xA => FrameType::Pong, non_control if (0x3..=7).contains(&non_control) => { FrameType::OtherNonControl(non_control) } control => FrameType::OtherControl(control), }; let mask = get_bool(7, first_line[1]); let length = match first_line[1] & 0b01111111 { 126 => self.stream.read_u16().await? as u64, 127 => self.stream.read_u64().await?, other => other as u64, }; let masking_key = if mask { self.stream.read_u32().await? } else { 0 }; let mut message_data = vec![0u8; length as usize]; self.stream.read_exact(&mut message_data).await?; if mask { Self::unmask_block(&mut message_data, masking_key); } Ok(DataBlock { is_final, e1: extension_bit_1, e2: extension_bit_2, e3: extension_bit_3, message_type, data: message_data, }) } pub async fn initialize_connection( req: Request, mut stream: TcpStream, ) -> tokio::io::Result { let (mut upgrade, mut connection, mut key_exists) = (false, false, false); let mut key_val: Box = "".into(); for i in req.headers { match i { RequestHeader::Upgrade(upgrad) => { if let Some(upg) = upgrad.first() && upg.protocol == Protocol::Websocket { upgrade = true; } } RequestHeader::Connection(con) => { if con == Connection::Upgrade { connection = true; } else if let Connection::Other(c) = con && c.contains("Upgrade") { connection = true; } } RequestHeader::Other { name, value } => { if name == "Sec-WebSocket-Key".into() { key_val = value.clone(); key_exists = true; } } _ => (), } } if upgrade && connection && key_exists { let magic_val = b"258EAFA5-E914-47DA-95CA-C5AB0DC85B11"; let mut hasher = Sha1::new(); hasher.update(key_val.as_bytes()); hasher.update(magic_val); let result = hasher.finalize(); let result = BASE64_STANDARD.encode(result); let rep = Response::new() .with_code(crate::response::ResponseCode::SwitchingProtocols) .with_header(crate::response::ResponseHeader::Upgrade(Upgrade { protocol: Protocol::Websocket, version: None, })) .with_header(crate::response::ResponseHeader::Connection( Connection::Upgrade, )) .with_header(crate::response::ResponseHeader::Other { header_name: "Sec-WebSocket-Accept".into(), header_value: result.into(), }); rep.respond(&mut stream).await?; Ok(Self { stream }) } else { Response::new() .with_code(crate::response::ResponseCode::BadRequest) .respond(&mut stream) .await?; stream.flush().await?; Err(io::Error::new(io::ErrorKind::InvalidData, "Wrong request")) } } }