Initial commit.
This commit is contained in:
25
mlc-server/Cargo.toml
Normal file
25
mlc-server/Cargo.toml
Normal file
@ -0,0 +1,25 @@
|
||||
[package]
|
||||
name = "mlc-server"
|
||||
version = "0.1.0"
|
||||
edition = "2024"
|
||||
|
||||
[dependencies]
|
||||
mls-rs.workspace = true
|
||||
mls-rs-core.workspace = true
|
||||
mls-rs-crypto-openssl.workspace = true
|
||||
anyhow.workspace = true
|
||||
reqwest.workspace = true
|
||||
thiserror.workspace = true
|
||||
serde_json.workspace = true
|
||||
tracing.workspace = true
|
||||
tracing-subscriber.workspace = true
|
||||
|
||||
clap = { version = "^4", features = ["derive", "env"] }
|
||||
tokio = { version = "1.44.1", features = ["full"] }
|
||||
sqlx = { version = "0.8.3", features = ["sqlite", "runtime-tokio"] }
|
||||
dotenv = "0.15.0"
|
||||
axum = { version = "0.8.1", features = ["ws"] }
|
||||
rustc-hash = { version = "2.1.1", features = ["rand"] }
|
||||
uuid = { version = "1.16.0", features = ["v4"] }
|
||||
futures = "0.3.31"
|
||||
|
0
mlc-server/Dockerfile
Normal file
0
mlc-server/Dockerfile
Normal file
5
mlc-server/build.rs
Normal file
5
mlc-server/build.rs
Normal file
@ -0,0 +1,5 @@
|
||||
// generated by `sqlx migrate build-script`
|
||||
fn main() {
|
||||
// trigger recompilation when a new migration is added
|
||||
println!("cargo:rerun-if-changed=migrations");
|
||||
}
|
6
mlc-server/migrations/20250323085527_init.down.sql
Normal file
6
mlc-server/migrations/20250323085527_init.down.sql
Normal file
@ -0,0 +1,6 @@
|
||||
DROP TABLE `mlc_keypackage`;
|
||||
DROP TABLE `mlc_group_participant`;
|
||||
DROP TABLE `mlc_group`;
|
||||
DROP TABLE `mlc_device`;
|
||||
DROP TABLE `mlc_user`;
|
||||
DROP TABLE `mlc_inbox`;
|
40
mlc-server/migrations/20250323085527_init.up.sql
Normal file
40
mlc-server/migrations/20250323085527_init.up.sql
Normal file
@ -0,0 +1,40 @@
|
||||
CREATE TABLE mlc_user(
|
||||
username TEXT PRIMARY KEY,
|
||||
password TEXT,
|
||||
email TEXT
|
||||
) WITHOUT ROWID;
|
||||
|
||||
CREATE TABLE mlc_device(
|
||||
id TEXT,
|
||||
username TEXT,
|
||||
PRIMARY KEY (id, username),
|
||||
FOREIGN KEY (username) REFERENCES mlc_user(username)
|
||||
) WITHOUT ROWID;
|
||||
|
||||
CREATE TABLE mlc_group(id TEXT PRIMARY KEY) WITHOUT ROWID;
|
||||
|
||||
CREATE TABLE mlc_group_participant(
|
||||
id TEXT PRIMARY KEY,
|
||||
group_id BIGINT,
|
||||
device_id TEXT,
|
||||
username TEXT,
|
||||
FOREIGN KEY (group_id) REFERENCES mlc_group(id),
|
||||
FOREIGN KEY (device_id, username) REFERENCES mlc_device(id, username)
|
||||
) WITHOUT ROWID;
|
||||
|
||||
CREATE TABLE mlc_keypackage(
|
||||
hash_ref BLOB PRIMARY KEY,
|
||||
device_id TEXT,
|
||||
username TEXT,
|
||||
expiration BIGINT,
|
||||
given BOOLEAN,
|
||||
data BLOB,
|
||||
FOREIGN KEY (device_id, username) REFERENCES mlc_device(id, username)
|
||||
) WITHOUT ROWID;
|
||||
|
||||
CREATE TABLE mlc_inbox(
|
||||
id TEXT PRIMARY KEY,
|
||||
device_id TEXT,
|
||||
username TEXT,
|
||||
FOREIGN KEY (device_id, username) REFERENCES mlc_device(id, username)
|
||||
) WITHOUT ROWID;
|
14
mlc-server/src/args.rs
Normal file
14
mlc-server/src/args.rs
Normal file
@ -0,0 +1,14 @@
|
||||
#[derive(Debug, Clone, clap::Parser)]
|
||||
pub struct MlcConfig {
|
||||
#[arg(long = "port", default_value = "8080", env = "MLC_SERVER_PORT")]
|
||||
pub server_port: u16,
|
||||
#[arg(long = "host", default_value = "0.0.0.0", env = "MLC_SERVER_HOST")]
|
||||
pub server_host: String,
|
||||
#[arg(long, short, default_value = "mlc.sqlite3", env = "MLC_DB_PATH")]
|
||||
pub db_path: String,
|
||||
#[arg(long, short, default_value = "secret", env = "MLC_JWT_SECRET")]
|
||||
pub jwt_secret: String,
|
||||
|
||||
#[arg(long, short, default_value = "100", env = "MLC_PER_USER_MSG_BUFFER")]
|
||||
pub per_user_msg_buffer: usize,
|
||||
}
|
0
mlc-server/src/daos/mod.rs
Normal file
0
mlc-server/src/daos/mod.rs
Normal file
0
mlc-server/src/jwt.rs
Normal file
0
mlc-server/src/jwt.rs
Normal file
163
mlc-server/src/main.rs
Normal file
163
mlc-server/src/main.rs
Normal file
@ -0,0 +1,163 @@
|
||||
use std::{fs::OpenOptions, path::Path};
|
||||
|
||||
use axum::{
|
||||
Extension,
|
||||
extract::{State, ws::Message},
|
||||
http::HeaderMap,
|
||||
response::IntoResponse,
|
||||
};
|
||||
use clap::Parser;
|
||||
use futures::SinkExt;
|
||||
use futures::StreamExt;
|
||||
use mls_rs::mls_rs_codec::MlsDecode;
|
||||
use sqlx::sqlite::{SqlitePool, SqlitePoolOptions};
|
||||
use state::MlcState;
|
||||
use tokio::net::TcpListener;
|
||||
use tracing::level_filters::LevelFilter;
|
||||
use tracing_subscriber::{EnvFilter, FmtSubscriber};
|
||||
|
||||
mod args;
|
||||
mod daos;
|
||||
mod routes;
|
||||
mod services;
|
||||
mod state;
|
||||
|
||||
async fn start() -> anyhow::Result<()> {
|
||||
let config = args::MlcConfig::parse();
|
||||
// Get the database URL from environment variables
|
||||
{
|
||||
if let Some(parent) = Path::new(&config.db_path).parent() {
|
||||
std::fs::create_dir_all(&parent)?;
|
||||
}
|
||||
OpenOptions::new()
|
||||
.create(true)
|
||||
.write(true)
|
||||
.open(&config.db_path)?;
|
||||
}
|
||||
tracing::info!("Database path: {}", config.db_path);
|
||||
let sqlite_pool = SqlitePoolOptions::new()
|
||||
.max_connections(10)
|
||||
.min_connections(4)
|
||||
.max_lifetime(None)
|
||||
.connect(&format!("sqlite:{}", config.db_path))
|
||||
.await
|
||||
.expect("Failed to connect to database");
|
||||
|
||||
// Run migrations
|
||||
sqlx::migrate!("./migrations")
|
||||
.run(&sqlite_pool)
|
||||
.await
|
||||
.expect("Failed to run migrations");
|
||||
|
||||
let state = state::MlcState {
|
||||
sqlite_pool,
|
||||
ws_manager: services::WSManager::default(),
|
||||
config: config.clone(),
|
||||
};
|
||||
|
||||
let app = axum::Router::new()
|
||||
.route("/ws", axum::routing::get(ws_handler))
|
||||
.with_state(state);
|
||||
|
||||
let listener = TcpListener::bind((config.server_host, config.server_port)).await?;
|
||||
tracing::info!("Listening on: {}", listener.local_addr()?);
|
||||
axum::serve(listener, app).await?;
|
||||
Ok(())
|
||||
}
|
||||
|
||||
async fn ws_handler(
|
||||
ws: axum::extract::ws::WebSocketUpgrade,
|
||||
headers: HeaderMap,
|
||||
State(db_pool): State<MlcState>,
|
||||
) -> impl IntoResponse {
|
||||
ws.on_upgrade(move |socket| {
|
||||
let username = headers.get("username").unwrap().to_str().unwrap();
|
||||
let (ws_writer, ws_reader) = socket.split();
|
||||
handle_messages(ws_writer, ws_reader, username.to_string(), db_pool)
|
||||
})
|
||||
}
|
||||
|
||||
async fn handle_messages<W, R>(
|
||||
mut ws_writer: W,
|
||||
mut ws_reader: R,
|
||||
username: String,
|
||||
state: MlcState,
|
||||
) where
|
||||
W: futures::Sink<Message> + Unpin,
|
||||
R: futures::Stream<Item = Result<Message, axum::Error>> + Unpin,
|
||||
<W as futures::Sink<Message>>::Error: std::fmt::Display,
|
||||
{
|
||||
let (messages_tx, mut messages_rx) =
|
||||
tokio::sync::mpsc::channel::<String>(state.config.per_user_msg_buffer);
|
||||
let ws_id = state.ws_manager.add_client(&username, messages_tx).await;
|
||||
tracing::Span::current().record("ws_id", ws_id);
|
||||
let mut messages = vec![];
|
||||
loop {
|
||||
tokio::select! {
|
||||
msg = messages_rx.recv() => {
|
||||
if let Some(msg) = msg {
|
||||
ws_writer.send(Message::Text(msg.into())).await.ok();
|
||||
}else{
|
||||
tracing::debug!("Messages channel has been closed");
|
||||
break;
|
||||
}
|
||||
},
|
||||
msg = ws_reader.next() => {
|
||||
let Some(msg_res) = msg else {
|
||||
tracing::debug!("Websocket connection has been closed");
|
||||
break;
|
||||
};
|
||||
let Ok(msg) = msg_res else {
|
||||
tracing::warn!("Cannot receive message from client: {}", msg_res.err().unwrap());
|
||||
break;
|
||||
};
|
||||
match msg {
|
||||
Message::Close(_) => {
|
||||
tracing::debug!("Client has closed the connection");
|
||||
break;
|
||||
}
|
||||
// Message::Ping(bytes) => {ws_writer.send(Message::Pong(bytes)).await.ok();},
|
||||
Message::Text(utf8_bytes) => {
|
||||
tracing::info!("Received message from ws: {:?}", utf8_bytes);
|
||||
if let Err(err) = ws_writer.send(Message::Text(utf8_bytes)).await {
|
||||
tracing::error!("Failed to send message back: {}", err);
|
||||
break;
|
||||
}
|
||||
},
|
||||
Message::Binary(bytes) => {
|
||||
// bytes.
|
||||
let mut raw_bytes = bytes.as_ref();
|
||||
while raw_bytes.len() > 0 {
|
||||
if let Ok(msg) = mls_rs::MlsMessage::mls_decode(&mut raw_bytes){
|
||||
messages.push(msg);
|
||||
} else {break}
|
||||
}
|
||||
},
|
||||
_ => {}
|
||||
}
|
||||
}
|
||||
}
|
||||
println!("{:?}", messages);
|
||||
messages.truncate(0);
|
||||
}
|
||||
}
|
||||
|
||||
fn main() -> anyhow::Result<()> {
|
||||
// Load environment variables from .env file
|
||||
dotenv::dotenv().ok();
|
||||
FmtSubscriber::builder()
|
||||
.with_env_filter(
|
||||
EnvFilter::builder()
|
||||
.with_default_directive(LevelFilter::INFO.into())
|
||||
.from_env_lossy(),
|
||||
)
|
||||
.with_writer(std::io::stderr)
|
||||
.init();
|
||||
|
||||
// Start the HTTP server
|
||||
let rt = tokio::runtime::Builder::new_multi_thread()
|
||||
.enable_all()
|
||||
.build()?;
|
||||
|
||||
rt.block_on(start())
|
||||
}
|
0
mlc-server/src/routes/mod.rs
Normal file
0
mlc-server/src/routes/mod.rs
Normal file
3
mlc-server/src/services/mod.rs
Normal file
3
mlc-server/src/services/mod.rs
Normal file
@ -0,0 +1,3 @@
|
||||
mod ws_manager;
|
||||
|
||||
pub use ws_manager::WSManager;
|
148
mlc-server/src/services/ws_manager.rs
Normal file
148
mlc-server/src/services/ws_manager.rs
Normal file
@ -0,0 +1,148 @@
|
||||
use std::sync::Arc;
|
||||
|
||||
use rustc_hash::FxHashMap;
|
||||
use tokio::sync::{RwLock, mpsc::Sender};
|
||||
|
||||
#[derive(Clone, Debug, Default)]
|
||||
pub struct WSManager {
|
||||
connected_clients: Arc<RwLock<FxHashMap<String, Vec<Sender<String>>>>>,
|
||||
}
|
||||
|
||||
impl WSManager {
|
||||
#[tracing::instrument(level = "debug", skip(self, sender))]
|
||||
pub async fn add_client(&self, user_id: &str, sender: Sender<String>) -> usize {
|
||||
let mut connected_clients = self.connected_clients.write().await;
|
||||
let senders = connected_clients
|
||||
.entry(user_id.to_string())
|
||||
.or_insert_with(Vec::new);
|
||||
senders.push(sender);
|
||||
// Minus one, because we use zero-based indexing.
|
||||
senders.len() - 1
|
||||
}
|
||||
|
||||
#[tracing::instrument(level = "debug", skip(self))]
|
||||
pub async fn send_message(&self, user_id: &str, message: &str) -> bool {
|
||||
let mut dead_sockets = vec![];
|
||||
let mut message_sent = false;
|
||||
let user_id_str = user_id.to_string();
|
||||
// Here we are defining a scope to drop the read lock as soon as we are done with it.
|
||||
// Because later we will need to acquire a write lock on the same data.
|
||||
// If we won't drop the read lock, we will get a deadlock.
|
||||
{
|
||||
let connected_clients = self.connected_clients.read().await;
|
||||
if let Some(senders) = connected_clients.get(&user_id_str) {
|
||||
for (id, sender) in senders.iter().enumerate() {
|
||||
if sender.send(message.to_string()).await.is_err() {
|
||||
dead_sockets.push(id);
|
||||
}
|
||||
}
|
||||
// The delivery was successful only if we have
|
||||
// sent the message to at least one socket.
|
||||
if dead_sockets.len() < senders.len() {
|
||||
message_sent = true;
|
||||
}
|
||||
}
|
||||
}
|
||||
// If all sockets that are connected to the user are running fine, we can return.
|
||||
if dead_sockets.is_empty() {
|
||||
return message_sent;
|
||||
}
|
||||
|
||||
self.remove_clients(user_id, dead_sockets).await;
|
||||
|
||||
message_sent
|
||||
}
|
||||
|
||||
#[tracing::instrument(level = "debug", skip(self))]
|
||||
pub async fn remove_clients(&self, user_id: &str, ws_ids: Vec<usize>) {
|
||||
let mut connected_clients = self.connected_clients.write().await;
|
||||
let mut senders_left = 0;
|
||||
if let Some(senders) = connected_clients.get_mut(&user_id.to_string()) {
|
||||
for id in ws_ids {
|
||||
if id >= senders.len() {
|
||||
continue;
|
||||
}
|
||||
senders.swap_remove(id);
|
||||
}
|
||||
senders_left = senders.len();
|
||||
}
|
||||
if senders_left == 0 {
|
||||
connected_clients.remove(&user_id.to_string());
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
#[cfg(test)]
|
||||
mod tests {
|
||||
#[tokio::test]
|
||||
async fn add_client() {
|
||||
let ws_manager = super::WSManager::default();
|
||||
let (sender, _) = tokio::sync::mpsc::channel(1);
|
||||
let uid = uuid::Uuid::new_v4().to_string();
|
||||
let client_id = ws_manager.add_client(&uid, sender).await;
|
||||
assert_eq!(client_id, 0);
|
||||
let connected_clients = ws_manager.connected_clients.read().await;
|
||||
let sockets = connected_clients.get(&uid.to_string()).unwrap();
|
||||
assert_eq!(sockets.len(), 1);
|
||||
}
|
||||
|
||||
#[tokio::test]
|
||||
async fn sending_message() {
|
||||
let ws_manager = super::WSManager::default();
|
||||
let (sender, mut receiver) = tokio::sync::mpsc::channel(1);
|
||||
let uid = uuid::Uuid::new_v4().to_string();
|
||||
let client_id = ws_manager.add_client(&uid, sender).await;
|
||||
assert_eq!(client_id, 0);
|
||||
ws_manager.send_message(&uid, "test").await;
|
||||
let message = receiver.recv().await;
|
||||
assert_eq!(message, Some("test".to_string()));
|
||||
}
|
||||
|
||||
#[tokio::test]
|
||||
async fn multiple_clients() {
|
||||
let ws_manager = super::WSManager::default();
|
||||
let mut receivers = vec![];
|
||||
let uid = uuid::Uuid::new_v4().to_string();
|
||||
let num_clients = 10;
|
||||
for i in 0..num_clients {
|
||||
let (sender, receiver) = tokio::sync::mpsc::channel(1);
|
||||
let client_id = ws_manager.add_client(&uid, sender).await;
|
||||
assert_eq!(client_id, i);
|
||||
receivers.push(receiver);
|
||||
}
|
||||
ws_manager.send_message(&uid, "test").await;
|
||||
for mut receiver in receivers {
|
||||
let message = receiver.recv().await;
|
||||
assert_eq!(message, Some("test".to_string()));
|
||||
}
|
||||
}
|
||||
|
||||
#[tokio::test]
|
||||
async fn remove_client() {
|
||||
let ws_manager = super::WSManager::default();
|
||||
let (sender, mut receiver) = tokio::sync::mpsc::channel(1);
|
||||
let uid = uuid::Uuid::new_v4().to_string();
|
||||
let client_id = ws_manager.add_client(&uid, sender).await;
|
||||
assert_eq!(client_id, 0);
|
||||
ws_manager.remove_clients(&uid, vec![0, 1, 2]).await;
|
||||
ws_manager.send_message(&uid, "test").await;
|
||||
let message = receiver.recv().await;
|
||||
assert_eq!(message, None);
|
||||
let all_clients = ws_manager.connected_clients.read().await;
|
||||
assert_eq!(all_clients.len(), 0);
|
||||
}
|
||||
|
||||
#[tokio::test]
|
||||
async fn removing_dead_client() {
|
||||
let ws_manager = super::WSManager::default();
|
||||
let (sender, mut receiver) = tokio::sync::mpsc::channel(1);
|
||||
let uid = uuid::Uuid::new_v4().to_string();
|
||||
let client_id = ws_manager.add_client(&uid, sender).await;
|
||||
receiver.close();
|
||||
assert_eq!(client_id, 0);
|
||||
let delivered = ws_manager.send_message(&uid, "test").await;
|
||||
assert_eq!(delivered, false);
|
||||
let all_clients = ws_manager.connected_clients.read().await;
|
||||
assert_eq!(all_clients.len(), 0);
|
||||
}
|
||||
}
|
10
mlc-server/src/state.rs
Normal file
10
mlc-server/src/state.rs
Normal file
@ -0,0 +1,10 @@
|
||||
use sqlx::SqlitePool;
|
||||
|
||||
use crate::{args::MlcConfig, services::WSManager};
|
||||
|
||||
#[derive(Clone, Debug)]
|
||||
pub struct MlcState {
|
||||
pub sqlite_pool: SqlitePool,
|
||||
pub ws_manager: WSManager,
|
||||
pub config: MlcConfig,
|
||||
}
|
Reference in New Issue
Block a user