multiplayer-game/src/websoket_connection.rs
2025-11-04 22:41:01 +01:00

240 lines
7.4 KiB
Rust

use crate::{
request::{Connection, Protocol, Request, RequestHeader, Upgrade},
response::Response,
};
use tokio::io::{self, AsyncReadExt, AsyncWriteExt};
use tokio::net::TcpStream;
use base64::prelude::*;
use sha1::{Digest, Sha1};
/// A server-side WebSocket connection created by
/// [`WebsocketConnection::initialize_connection`] on top of a raw TCP stream.
#[derive(Debug)]
pub struct WebsocketConnection {
    /// The underlying socket; frames are read from and written to it directly.
    stream: TcpStream,
}
/// A single frame as read off the wire, with its payload already unmasked.
struct DataBlock {
    /// Opcode of this frame.
    message_type: FrameType,
    /// FIN bit: set when this frame is the last fragment of its message.
    is_final: bool,
    /// RSV1 bit.
    e1: bool,
    /// RSV2 bit.
    e2: bool,
    /// RSV3 bit.
    e3: bool,
    /// Payload bytes of this frame.
    data: Vec<u8>,
}
/// A complete message handed to callers, with any fragments already joined.
pub struct DataFrame {
    /// Payload bytes of the whole message.
    pub data: Box<[u8]>,
    /// Kind of message, as announced by its first frame.
    pub frame_type: FrameType,
}
/// WebSocket frame opcode (RFC 6455 §5.2).
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum FrameType {
    /// Opcode `0x0`: continues a fragmented message.
    Continuation,
    /// Opcode `0x1`: text data.
    TextFrame,
    /// Opcode `0x2`: binary data.
    BinaryFrame,
    /// Opcode `0x8`: close frame.
    ConnectionClose,
    /// Opcode `0x9`: ping.
    Ping,
    /// Opcode `0xA`: pong.
    Pong,
    /// Any other control opcode (`0xB`–`0xF`); holds the raw opcode.
    OtherControl(u8),
    /// Any other non-control opcode (`0x3`–`0x7`); holds the raw opcode.
    OtherNonControl(u8),
}
impl WebsocketConnection {
pub async fn send_message(&mut self, frame_type: FrameType, data: &[u8]) -> io::Result<()> {
let mut header = Vec::with_capacity(14); // Max header size for 127-length payload
// First byte: FIN (1) + RSV1-3 (000) + opcode
let opcode = match frame_type {
FrameType::TextFrame => 0x1,
FrameType::BinaryFrame => 0x2,
FrameType::Ping => 0x9,
FrameType::Pong => 0xA,
FrameType::ConnectionClose => 0x8,
_ => panic!("No other type should by passed to this function"),
};
header.push(0b1000_0000 | opcode); // FIN = 1
// Second byte: MASK bit = 0 (server -> client frames are NOT masked)
let payload_len = data.len();
if payload_len < 126 {
header.push(payload_len as u8);
} else if payload_len <= u16::MAX as usize {
header.push(126);
header.extend_from_slice(&(payload_len as u16).to_be_bytes());
} else {
header.push(127);
header.extend_from_slice(&(payload_len as u64).to_be_bytes());
}
// Send header + payload
self.stream.write_all(&header).await?;
self.stream.write_all(data).await?;
self.stream.flush().await?;
Ok(())
}
pub async fn read_next_message(&mut self) -> io::Result<DataFrame> {
let first_line = self.parse_single_block().await?;
let mut data = first_line.data;
let frame_type = first_line.message_type;
if !first_line.is_final {
let mut current_line = self.parse_single_block().await?;
while !current_line.is_final {
if current_line.message_type != FrameType::Continuation {
return Err(io::Error::new(
io::ErrorKind::InvalidInput,
"That is not how websocket works!!!",
));
}
data.extend_from_slice(&current_line.data);
current_line = self.parse_single_block().await?;
}
data.extend_from_slice(&current_line.data);
}
Ok(DataFrame {
frame_type,
data: data.into_boxed_slice(),
})
}
fn unmask_block(data: &mut [u8], mask: u32) {
let mask_bytes = mask.to_be_bytes();
for (i, e) in data.iter_mut().enumerate() {
*e ^= mask_bytes[i % 4];
}
}
async fn parse_single_block(&mut self) -> io::Result<DataBlock> {
let mut first_line: [u8; 2] = [0; 2];
self.stream.read_exact(&mut first_line).await?;
let get_bool = |index: u8, byte: u8| -> bool { byte & (1 << index) != 0 };
let is_final = get_bool(7, first_line[0]);
let extension_bit_1 = get_bool(6, first_line[0]);
let extension_bit_2 = get_bool(5, first_line[0]);
let extension_bit_3 = get_bool(4, first_line[0]);
let message_type = match first_line[0] & 0b00001111 {
0x0 => FrameType::Continuation,
0x1 => FrameType::TextFrame,
0x2 => FrameType::BinaryFrame,
0x8 => FrameType::ConnectionClose,
0x9 => FrameType::Ping,
0xA => FrameType::Pong,
non_control if (0x3..=7).contains(&non_control) => {
FrameType::OtherNonControl(non_control)
}
control => FrameType::OtherControl(control),
};
let mask = get_bool(7, first_line[1]);
let length = match first_line[1] & 0b01111111 {
126 => self.stream.read_u16().await? as u64,
127 => self.stream.read_u64().await?,
other => other as u64,
};
let masking_key = if mask {
self.stream.read_u32().await?
} else {
0
};
let mut message_data = vec![0u8; length as usize];
self.stream.read_exact(&mut message_data).await?;
if mask {
Self::unmask_block(&mut message_data, masking_key);
}
Ok(DataBlock {
is_final,
e1: extension_bit_1,
e2: extension_bit_2,
e3: extension_bit_3,
message_type,
data: message_data,
})
}
pub async fn initialize_connection(
req: Request,
mut stream: TcpStream,
) -> tokio::io::Result<Self> {
let (mut upgrade, mut connection, mut key_exists) = (false, false, false);
let mut key_val: Box<str> = "".into();
for i in req.headers {
match i {
RequestHeader::Upgrade(upgrad) => {
if let Some(upg) = upgrad.first()
&& upg.protocol == Protocol::Websocket
{
upgrade = true;
}
}
RequestHeader::Connection(con) => {
if con == Connection::Upgrade {
connection = true;
} else if let Connection::Other(c) = con
&& c.contains("Upgrade")
{
connection = true;
}
}
RequestHeader::Other { name, value } => {
if name == "Sec-WebSocket-Key".into() {
key_val = value.clone();
key_exists = true;
}
}
_ => (),
}
}
if upgrade && connection && key_exists {
let magic_val = b"258EAFA5-E914-47DA-95CA-C5AB0DC85B11";
let mut hasher = Sha1::new();
hasher.update(key_val.as_bytes());
hasher.update(magic_val);
let result = hasher.finalize();
let result = BASE64_STANDARD.encode(result);
let rep = Response::new()
.with_code(crate::response::ResponseCode::SwitchingProtocols)
.with_header(crate::response::ResponseHeader::Upgrade(Upgrade {
protocol: Protocol::Websocket,
version: None,
}))
.with_header(crate::response::ResponseHeader::Connection(
Connection::Upgrade,
))
.with_header(crate::response::ResponseHeader::Other {
header_name: "Sec-WebSocket-Accept".into(),
header_value: result.into(),
});
rep.respond(&mut stream).await?;
Ok(Self { stream })
} else {
Response::new()
.with_code(crate::response::ResponseCode::BadRequest)
.respond(&mut stream)
.await?;
stream.flush().await?;
Err(io::Error::new(io::ErrorKind::InvalidData, "Wrong request"))
}
}
}