240 lines
7.4 KiB
Rust
240 lines
7.4 KiB
Rust
use crate::{
|
|
request::{Connection, Protocol, Request, RequestHeader, Upgrade},
|
|
response::Response,
|
|
};
|
|
|
|
use tokio::io::{self, AsyncReadExt, AsyncWriteExt};
|
|
use tokio::net::TcpStream;
|
|
|
|
use base64::prelude::*;
|
|
use sha1::{Digest, Sha1};
|
|
|
|
pub struct WebsocketConnection {
|
|
stream: TcpStream,
|
|
}
|
|
|
|
struct DataBlock {
|
|
is_final: bool,
|
|
e1: bool,
|
|
e2: bool,
|
|
e3: bool,
|
|
|
|
message_type: FrameType,
|
|
|
|
data: Vec<u8>,
|
|
}
|
|
|
|
pub struct DataFrame {
|
|
pub frame_type: FrameType,
|
|
pub data: Box<[u8]>,
|
|
}
|
|
|
|
/// Opcode of a WebSocket frame (the low 4 bits of the first header byte).
///
/// Known opcodes have dedicated variants. Reserved opcodes are preserved with
/// their raw 4-bit value: `OtherNonControl` for 0x3–0x7 and `OtherControl`
/// for 0xB–0xF.
///
/// The type is a plain value (unit variants or a single `u8`), so it derives
/// `Copy`, `Debug` and `Hash` in addition to equality.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
pub enum FrameType {
    /// Opcode `0x0`: continuation of a fragmented message.
    Continuation,
    /// Opcode `0x1`: text frame.
    TextFrame,
    /// Opcode `0x2`: binary frame.
    BinaryFrame,
    /// Opcode `0x8`: connection close (control frame).
    ConnectionClose,
    /// Opcode `0x9`: ping (control frame).
    Ping,
    /// Opcode `0xA`: pong (control frame).
    Pong,
    /// Any other control opcode (0xB–0xF), kept as its raw value.
    OtherControl(u8),
    /// Any other non-control opcode (0x3–0x7), kept as its raw value.
    OtherNonControl(u8),
}
|
|
|
|
impl WebsocketConnection {
|
|
pub async fn send_message(&mut self, frame_type: FrameType, data: &[u8]) -> io::Result<()> {
|
|
let mut header = Vec::with_capacity(14); // Max header size for 127-length payload
|
|
|
|
// First byte: FIN (1) + RSV1-3 (000) + opcode
|
|
let opcode = match frame_type {
|
|
FrameType::TextFrame => 0x1,
|
|
FrameType::BinaryFrame => 0x2,
|
|
FrameType::Ping => 0x9,
|
|
FrameType::Pong => 0xA,
|
|
FrameType::ConnectionClose => 0x8,
|
|
_ => panic!("No other type should by passed to this function"),
|
|
};
|
|
header.push(0b1000_0000 | opcode); // FIN = 1
|
|
|
|
// Second byte: MASK bit = 0 (server -> client frames are NOT masked)
|
|
let payload_len = data.len();
|
|
if payload_len < 126 {
|
|
header.push(payload_len as u8);
|
|
} else if payload_len <= u16::MAX as usize {
|
|
header.push(126);
|
|
header.extend_from_slice(&(payload_len as u16).to_be_bytes());
|
|
} else {
|
|
header.push(127);
|
|
header.extend_from_slice(&(payload_len as u64).to_be_bytes());
|
|
}
|
|
|
|
// Send header + payload
|
|
self.stream.write_all(&header).await?;
|
|
self.stream.write_all(data).await?;
|
|
self.stream.flush().await?;
|
|
|
|
Ok(())
|
|
}
|
|
|
|
pub async fn read_next_message(&mut self) -> io::Result<DataFrame> {
|
|
let first_line = self.parse_single_block().await?;
|
|
|
|
let mut data = first_line.data;
|
|
let frame_type = first_line.message_type;
|
|
|
|
if !first_line.is_final {
|
|
let mut current_line = self.parse_single_block().await?;
|
|
|
|
while !current_line.is_final {
|
|
if current_line.message_type != FrameType::Continuation {
|
|
return Err(io::Error::new(
|
|
io::ErrorKind::InvalidInput,
|
|
"That is not how websocket works!!!",
|
|
));
|
|
}
|
|
data.extend_from_slice(¤t_line.data);
|
|
current_line = self.parse_single_block().await?;
|
|
}
|
|
|
|
data.extend_from_slice(¤t_line.data);
|
|
}
|
|
|
|
Ok(DataFrame {
|
|
frame_type,
|
|
data: data.into_boxed_slice(),
|
|
})
|
|
}
|
|
|
|
fn unmask_block(data: &mut [u8], mask: u32) {
|
|
let mask_bytes = mask.to_be_bytes();
|
|
for (i, e) in data.iter_mut().enumerate() {
|
|
*e ^= mask_bytes[i % 4];
|
|
}
|
|
}
|
|
|
|
async fn parse_single_block(&mut self) -> io::Result<DataBlock> {
|
|
let mut first_line: [u8; 2] = [0; 2];
|
|
self.stream.read_exact(&mut first_line).await?;
|
|
|
|
let get_bool = |index: u8, byte: u8| -> bool { byte & (1 << index) != 0 };
|
|
|
|
let is_final = get_bool(7, first_line[0]);
|
|
let extension_bit_1 = get_bool(6, first_line[0]);
|
|
let extension_bit_2 = get_bool(5, first_line[0]);
|
|
let extension_bit_3 = get_bool(4, first_line[0]);
|
|
|
|
let message_type = match first_line[0] & 0b00001111 {
|
|
0x0 => FrameType::Continuation,
|
|
0x1 => FrameType::TextFrame,
|
|
0x2 => FrameType::BinaryFrame,
|
|
0x8 => FrameType::ConnectionClose,
|
|
0x9 => FrameType::Ping,
|
|
0xA => FrameType::Pong,
|
|
|
|
non_control if (0x3..=7).contains(&non_control) => {
|
|
FrameType::OtherNonControl(non_control)
|
|
}
|
|
control => FrameType::OtherControl(control),
|
|
};
|
|
|
|
let mask = get_bool(7, first_line[1]);
|
|
|
|
let length = match first_line[1] & 0b01111111 {
|
|
126 => self.stream.read_u16().await? as u64,
|
|
127 => self.stream.read_u64().await?,
|
|
other => other as u64,
|
|
};
|
|
|
|
let masking_key = if mask {
|
|
self.stream.read_u32().await?
|
|
} else {
|
|
0
|
|
};
|
|
|
|
let mut message_data = vec![0u8; length as usize];
|
|
self.stream.read_exact(&mut message_data).await?;
|
|
|
|
if mask {
|
|
Self::unmask_block(&mut message_data, masking_key);
|
|
}
|
|
|
|
Ok(DataBlock {
|
|
is_final,
|
|
e1: extension_bit_1,
|
|
e2: extension_bit_2,
|
|
e3: extension_bit_3,
|
|
message_type,
|
|
data: message_data,
|
|
})
|
|
}
|
|
|
|
pub async fn initialize_connection(
|
|
req: Request,
|
|
mut stream: TcpStream,
|
|
) -> tokio::io::Result<Self> {
|
|
let (mut upgrade, mut connection, mut key_exists) = (false, false, false);
|
|
let mut key_val: Box<str> = "".into();
|
|
|
|
for i in req.headers {
|
|
match i {
|
|
RequestHeader::Upgrade(upgrad) => {
|
|
if let Some(upg) = upgrad.first()
|
|
&& upg.protocol == Protocol::Websocket
|
|
{
|
|
upgrade = true;
|
|
}
|
|
}
|
|
RequestHeader::Connection(con) => {
|
|
if con == Connection::Upgrade {
|
|
connection = true;
|
|
} else if let Connection::Other(c) = con
|
|
&& c.contains("Upgrade")
|
|
{
|
|
connection = true;
|
|
}
|
|
}
|
|
RequestHeader::Other { name, value } => {
|
|
if name == "Sec-WebSocket-Key".into() {
|
|
key_val = value.clone();
|
|
key_exists = true;
|
|
}
|
|
}
|
|
_ => (),
|
|
}
|
|
}
|
|
|
|
if upgrade && connection && key_exists {
|
|
let magic_val = b"258EAFA5-E914-47DA-95CA-C5AB0DC85B11";
|
|
let mut hasher = Sha1::new();
|
|
hasher.update(key_val.as_bytes());
|
|
hasher.update(magic_val);
|
|
|
|
let result = hasher.finalize();
|
|
let result = BASE64_STANDARD.encode(result);
|
|
|
|
let rep = Response::new()
|
|
.with_code(crate::response::ResponseCode::SwitchingProtocols)
|
|
.with_header(crate::response::ResponseHeader::Upgrade(Upgrade {
|
|
protocol: Protocol::Websocket,
|
|
version: None,
|
|
}))
|
|
.with_header(crate::response::ResponseHeader::Connection(
|
|
Connection::Upgrade,
|
|
))
|
|
.with_header(crate::response::ResponseHeader::Other {
|
|
header_name: "Sec-WebSocket-Accept".into(),
|
|
header_value: result.into(),
|
|
});
|
|
rep.respond(&mut stream).await?;
|
|
|
|
Ok(Self { stream })
|
|
} else {
|
|
Response::new()
|
|
.with_code(crate::response::ResponseCode::BadRequest)
|
|
.respond(&mut stream)
|
|
.await?;
|
|
stream.flush().await?;
|
|
Err(io::Error::new(io::ErrorKind::InvalidData, "Wrong request"))
|
|
}
|
|
}
|
|
}
|