mod request; mod response; mod shared_enums; mod websoket_connection; use std::time::Duration; use std::{path::Path, str::FromStr}; use tokio::io::{self, AsyncWriteExt}; use tokio::net::{TcpListener, TcpStream}; use tokio::time; use crate::websoket_connection::{FrameType, WebsocketRead, WebsocketWrite}; use crate::{ request::{Connection, ServerPath}, response::{Response, ResponseCode, ResponseHeader}, }; use tokio::sync; #[tokio::main] async fn main() -> tokio::io::Result<()> { let listener = TcpListener::bind("127.0.0.1:8080").await?; let (sender, _) = sync::broadcast::channel(16); loop { let (stream, _) = listener.accept().await?; let receiver = sender.subscribe(); let sender = sender.clone(); tokio::spawn(handle_connection(stream, receiver, sender)); } } async fn handle_connection( stream: TcpStream, receiver: sync::broadcast::Receiver, sender: sync::broadcast::Sender, ) -> tokio::io::Result<()> { if let Some(ws) = handle_http_connection(stream).await? { handle_websocket(ws, receiver, sender).await? } Ok(()) } async fn handle_http_connection( mut stream: TcpStream, ) -> tokio::io::Result> { let timeout = 50; loop { let req = match time::timeout( Duration::from_millis(timeout), request::Request::from_bufreader(&mut stream), ) .await { Ok(Ok(r)) => r, Ok(Err(e)) => { println!("Wrong request: {e}"); break; } Err(_) => { break; } }; let matchable = req.path.to_matchable(); let response = match matchable.as_slice() { ["public", file] => { match Response::from_file(Path::new(format!("./public/{file}").as_str())) { Ok(resp) => resp, Err(_) => Response::new().with_code(ResponseCode::NotFound), } } ["websocket"] => match websoket_connection::initialize_connection(req, stream).await { Ok(ws) => { return Ok(Some(ws)); } Err(e) => { return Err(e); } }, [] => Response::new() .with_code(ResponseCode::PermanentRedirect) .with_header(ResponseHeader::Connection(Connection::KeepAlive)) .with_header(ResponseHeader::Location( ServerPath::from_str("/public/index.html").unwrap(), )), _ => Response::new().with_code(ResponseCode::NotFound), }; response.respond(&mut stream).await?; stream.flush().await?; if req.headers.contains(&request::RequestHeader::Connection( request::Connection::Close, )) || !req.headers.contains(&request::RequestHeader::Connection( request::Connection::KeepAlive, )) { break; } } Ok(None) } async fn handle_websocket( mut web_socket: (WebsocketRead, WebsocketWrite), receiver: sync::broadcast::Receiver, sender: sync::broadcast::Sender, ) -> tokio::io::Result<()> { tokio::spawn(broadcast_message(web_socket.1, receiver)); loop { let message = web_socket.0.read_next_message().await?; if message.frame_type == FrameType::TextFrame { let s = String::from_utf8_lossy(&message.data).to_string(); println!("{}", s); let _ = sender.send(s); } } } enum BroadcastError { IoError(io::Error), BroadcastError(sync::broadcast::error::RecvError), } async fn broadcast_message( mut write: WebsocketWrite, mut receiver: sync::broadcast::Receiver, ) -> Result<(), BroadcastError> { loop { let new_message = match receiver.recv().await { Ok(s) => s, Err(e) => return Err(BroadcastError::BroadcastError(e)), }; match write .send_message(FrameType::TextFrame, new_message.as_bytes()) .await { Ok(()) => {} Err(e) => return Err(BroadcastError::IoError(e)), } } }