Initial commit.
This commit is contained in:
2
.gitignore
vendored
Normal file
2
.gitignore
vendored
Normal file
@ -0,0 +1,2 @@
|
|||||||
|
/target
|
||||||
|
*.sqlite3*
|
7162
Cargo.lock
generated
Normal file
7162
Cargo.lock
generated
Normal file
File diff suppressed because it is too large
Load Diff
20
Cargo.toml
Normal file
20
Cargo.toml
Normal file
@ -0,0 +1,20 @@
|
|||||||
|
[workspace]
|
||||||
|
resolver = "3"
|
||||||
|
members = ["mlc-client", "mlc-server"]
|
||||||
|
|
||||||
|
[workspace.package]
|
||||||
|
version = "0.0.1"
|
||||||
|
edition = "2024"
|
||||||
|
authors = ["s3rius <s3rius@le-memese.com>"]
|
||||||
|
description = "My Little Chat"
|
||||||
|
|
||||||
|
[workspace.dependencies]
|
||||||
|
mls-rs = { version = "^0", features = ["serde"] }
|
||||||
|
mls-rs-core = "^0"
|
||||||
|
mls-rs-crypto-openssl = "^0"
|
||||||
|
anyhow = { version = "^1", features = ["backtrace"] }
|
||||||
|
reqwest = { version = "^0", features = ["json", "multipart"] }
|
||||||
|
serde_json = "^1"
|
||||||
|
thiserror = "^2"
|
||||||
|
tracing = "^0"
|
||||||
|
tracing-subscriber = { version = "^0", features = ["env-filter"] }
|
29
mlc-client/Cargo.toml
Normal file
29
mlc-client/Cargo.toml
Normal file
@ -0,0 +1,29 @@
|
|||||||
|
[package]
|
||||||
|
name = "mlc-client"
|
||||||
|
version.workspace = true
|
||||||
|
edition.workspace = true
|
||||||
|
authors.workspace = true
|
||||||
|
description.workspace = true
|
||||||
|
|
||||||
|
[dependencies]
|
||||||
|
mls-rs.workspace = true
|
||||||
|
mls-rs-core.workspace = true
|
||||||
|
mls-rs-crypto-openssl.workspace = true
|
||||||
|
anyhow.workspace = true
|
||||||
|
reqwest.workspace = true
|
||||||
|
thiserror.workspace = true
|
||||||
|
serde_json.workspace = true
|
||||||
|
tracing.workspace = true
|
||||||
|
tracing-subscriber.workspace = true
|
||||||
|
# Crate-specific dependencies
|
||||||
|
rusqlite = { version = "^0", features = ["bundled", "chrono"] }
|
||||||
|
rusqlite_migration = "^1.3.1"
|
||||||
|
chrono = "0.4.40"
|
||||||
|
eframe = "0.31.1"
|
||||||
|
egui = "0.31.1"
|
||||||
|
rfd = "0.15.3"
|
||||||
|
egui_extras = { version = "0.31.1", features = ["image"] }
|
||||||
|
image = "0.25.5"
|
||||||
|
egui_inbox = "0.8.0"
|
||||||
|
directories = "6.0.0"
|
||||||
|
machine-uid = "0.5.3"
|
38
mlc-client/src/ctx.rs
Normal file
38
mlc-client/src/ctx.rs
Normal file
@ -0,0 +1,38 @@
|
|||||||
|
use std::sync::{Arc, Mutex};
|
||||||
|
|
||||||
|
use crate::{MlsClient, screens::Screens};
|
||||||
|
|
||||||
|
/// Context which is valid only when logged in.
|
||||||
|
#[derive(Clone)]
|
||||||
|
pub struct MlsContext {
|
||||||
|
/// Client for mls operations.
|
||||||
|
/// It has ciphers, crypto provider, etc.
|
||||||
|
/// It is used to create groups, send messages, etc.
|
||||||
|
pub mls_client: MlsClient,
|
||||||
|
/// ID of the user from the server.
|
||||||
|
pub user_id: String,
|
||||||
|
/// Current group which user is in.
|
||||||
|
pub current_group: Option<String>,
|
||||||
|
}
|
||||||
|
|
||||||
|
pub struct ClientCTX {
|
||||||
|
pub email: String,
|
||||||
|
pub password: String,
|
||||||
|
pub machine_id: String,
|
||||||
|
pub current_screen: Screens,
|
||||||
|
pub sqlite_connection: Arc<Mutex<rusqlite::Connection>>,
|
||||||
|
pub mls_context: Option<MlsContext>,
|
||||||
|
}
|
||||||
|
|
||||||
|
impl ClientCTX {
|
||||||
|
pub fn new(sqlite_connection: Arc<Mutex<rusqlite::Connection>>) -> Self {
|
||||||
|
Self {
|
||||||
|
email: String::new(),
|
||||||
|
password: String::new(),
|
||||||
|
machine_id: String::new(),
|
||||||
|
current_screen: Screens::Login,
|
||||||
|
sqlite_connection,
|
||||||
|
mls_context: None,
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
10
mlc-client/src/daos/error.rs
Normal file
10
mlc-client/src/daos/error.rs
Normal file
@ -0,0 +1,10 @@
|
|||||||
|
/// This error exist just to implement the `IntoAnyError` trait for the `MlsCLIDBError` enum.
|
||||||
|
/// This is necessary to convert the error to `anyhow::Error` in
|
||||||
|
/// the error that mls_rs can work with.
|
||||||
|
#[derive(Debug, thiserror::Error)]
|
||||||
|
pub enum MlsCLIDBError {
|
||||||
|
#[error("Error: {0}")]
|
||||||
|
AnyError(#[from] anyhow::Error),
|
||||||
|
}
|
||||||
|
|
||||||
|
impl mls_rs_core::error::IntoAnyError for MlsCLIDBError {}
|
177
mlc-client/src/daos/group_state.rs
Normal file
177
mlc-client/src/daos/group_state.rs
Normal file
@ -0,0 +1,177 @@
|
|||||||
|
use std::sync::{Arc, Mutex};
|
||||||
|
|
||||||
|
use mls_rs::GroupStateStorage;
|
||||||
|
use mls_rs_core::group::{EpochRecord, GroupState};
|
||||||
|
use rusqlite::{Connection, OptionalExtension};
|
||||||
|
|
||||||
|
#[derive(Debug, Clone)]
|
||||||
|
/// SQLite Storage for MLS group states.
|
||||||
|
pub struct MlsCliGroupStateStorage {
|
||||||
|
connection: Arc<Mutex<Connection>>,
|
||||||
|
max_epoch_retention: u64,
|
||||||
|
}
|
||||||
|
|
||||||
|
impl MlsCliGroupStateStorage {
|
||||||
|
pub fn new(connection: Arc<Mutex<Connection>>) -> MlsCliGroupStateStorage {
|
||||||
|
MlsCliGroupStateStorage {
|
||||||
|
connection,
|
||||||
|
max_epoch_retention: 3,
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
/// List all the group ids for groups that are stored.
|
||||||
|
pub fn group_ids(&self) -> anyhow::Result<Vec<Vec<u8>>> {
|
||||||
|
let connection = self.connection.lock().unwrap();
|
||||||
|
let mut statement = connection.prepare("SELECT group_id FROM mls_group")?;
|
||||||
|
let res = statement
|
||||||
|
.query_map([], |row| row.get::<_, Vec<u8>>(0))?
|
||||||
|
.try_fold(Vec::new(), |mut ids, id| {
|
||||||
|
ids.push(id?);
|
||||||
|
Ok::<_, anyhow::Error>(ids)
|
||||||
|
})?;
|
||||||
|
|
||||||
|
Ok(res)
|
||||||
|
}
|
||||||
|
|
||||||
|
/// Delete a group from storage.
|
||||||
|
pub fn delete_group(&self, group_id: &[u8]) -> anyhow::Result<()> {
|
||||||
|
let connection = self.connection.lock().unwrap();
|
||||||
|
|
||||||
|
connection
|
||||||
|
.execute(
|
||||||
|
"DELETE FROM mls_group WHERE group_id = ?",
|
||||||
|
rusqlite::params![group_id],
|
||||||
|
)
|
||||||
|
.map(|_| ())?;
|
||||||
|
Ok(())
|
||||||
|
}
|
||||||
|
|
||||||
|
pub fn max_epoch_retention(&self) -> u64 {
|
||||||
|
self.max_epoch_retention
|
||||||
|
}
|
||||||
|
|
||||||
|
fn get_snapshot_data(&self, group_id: &[u8]) -> anyhow::Result<Option<Vec<u8>>> {
|
||||||
|
let connection = self.connection.lock().unwrap();
|
||||||
|
|
||||||
|
connection
|
||||||
|
.query_row(
|
||||||
|
"SELECT snapshot FROM mls_group where group_id = ?",
|
||||||
|
[group_id],
|
||||||
|
|row| row.get::<_, Vec<u8>>(0),
|
||||||
|
)
|
||||||
|
.optional()
|
||||||
|
.map_err(anyhow::Error::from)
|
||||||
|
}
|
||||||
|
|
||||||
|
fn get_epoch_data(&self, group_id: &[u8], epoch_id: u64) -> anyhow::Result<Option<Vec<u8>>> {
|
||||||
|
let connection = self.connection.lock().unwrap();
|
||||||
|
|
||||||
|
let res = connection
|
||||||
|
.query_row(
|
||||||
|
"SELECT epoch_data FROM epoch where group_id = ? AND epoch_id = ?",
|
||||||
|
rusqlite::params![group_id, epoch_id],
|
||||||
|
|row| row.get::<_, Vec<u8>>(0),
|
||||||
|
)
|
||||||
|
.optional()?;
|
||||||
|
Ok(res)
|
||||||
|
}
|
||||||
|
|
||||||
|
fn max_epoch_id(&self, group_id: &[u8]) -> anyhow::Result<Option<u64>> {
|
||||||
|
let connection = self.connection.lock().unwrap();
|
||||||
|
|
||||||
|
connection
|
||||||
|
.query_row(
|
||||||
|
"SELECT MAX(epoch_id) FROM epoch WHERE group_id = ?",
|
||||||
|
rusqlite::params![group_id],
|
||||||
|
|row| row.get::<_, Option<u64>>(0),
|
||||||
|
)
|
||||||
|
.map_err(anyhow::Error::from)
|
||||||
|
}
|
||||||
|
|
||||||
|
fn update_group_state(
|
||||||
|
&self,
|
||||||
|
group_id: &[u8],
|
||||||
|
group_snapshot: Vec<u8>,
|
||||||
|
inserts: Vec<EpochRecord>,
|
||||||
|
updates: Vec<EpochRecord>,
|
||||||
|
) -> anyhow::Result<()> {
|
||||||
|
let mut max_epoch_id = None;
|
||||||
|
|
||||||
|
let mut connection = self.connection.lock().unwrap();
|
||||||
|
let transaction = connection.transaction()?;
|
||||||
|
|
||||||
|
// Upsert into the group table to set the most recent snapshot
|
||||||
|
transaction.execute(
|
||||||
|
"INSERT INTO mls_group (group_id, snapshot) VALUES (?, ?) ON CONFLICT(group_id) DO UPDATE SET snapshot=excluded.snapshot",
|
||||||
|
rusqlite::params![group_id, group_snapshot],
|
||||||
|
)?;
|
||||||
|
|
||||||
|
// Insert new epochs as needed
|
||||||
|
for epoch in inserts {
|
||||||
|
max_epoch_id = Some(epoch.id);
|
||||||
|
|
||||||
|
transaction
|
||||||
|
.execute(
|
||||||
|
"INSERT INTO epoch (group_id, epoch_id, epoch_data) VALUES (?, ?, ?)",
|
||||||
|
rusqlite::params![group_id, epoch.id, epoch.data],
|
||||||
|
)
|
||||||
|
.map(|_| ())?;
|
||||||
|
}
|
||||||
|
|
||||||
|
// Update existing epochs as needed
|
||||||
|
for update in updates {
|
||||||
|
max_epoch_id = Some(update.id);
|
||||||
|
transaction
|
||||||
|
.execute(
|
||||||
|
"UPDATE epoch SET epoch_data = ? WHERE group_id = ? AND epoch_id = ?",
|
||||||
|
rusqlite::params![update.data, group_id, update.id],
|
||||||
|
)
|
||||||
|
.map(|_| ())?;
|
||||||
|
}
|
||||||
|
|
||||||
|
// Delete old epochs as needed
|
||||||
|
if let Some(max_epoch_id) = max_epoch_id {
|
||||||
|
if max_epoch_id >= self.max_epoch_retention {
|
||||||
|
let delete_under = max_epoch_id - self.max_epoch_retention;
|
||||||
|
|
||||||
|
transaction.execute(
|
||||||
|
"DELETE FROM epoch WHERE group_id = ? AND epoch_id <= ?",
|
||||||
|
rusqlite::params![group_id, delete_under],
|
||||||
|
)?;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// Execute the full transaction
|
||||||
|
transaction.commit()?;
|
||||||
|
Ok(())
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
impl GroupStateStorage for MlsCliGroupStateStorage {
|
||||||
|
type Error = super::error::MlsCLIDBError;
|
||||||
|
|
||||||
|
fn write(
|
||||||
|
&mut self,
|
||||||
|
state: GroupState,
|
||||||
|
inserts: Vec<EpochRecord>,
|
||||||
|
updates: Vec<EpochRecord>,
|
||||||
|
) -> Result<(), Self::Error> {
|
||||||
|
let group_id = state.id;
|
||||||
|
let snapshot_data = state.data;
|
||||||
|
|
||||||
|
self.update_group_state(&group_id, snapshot_data, inserts, updates)
|
||||||
|
.map_err(From::from)
|
||||||
|
}
|
||||||
|
|
||||||
|
fn state(&self, group_id: &[u8]) -> Result<Option<Vec<u8>>, Self::Error> {
|
||||||
|
self.get_snapshot_data(group_id).map_err(From::from)
|
||||||
|
}
|
||||||
|
|
||||||
|
fn max_epoch_id(&self, group_id: &[u8]) -> Result<Option<u64>, Self::Error> {
|
||||||
|
self.max_epoch_id(group_id).map_err(From::from)
|
||||||
|
}
|
||||||
|
|
||||||
|
fn epoch(&self, group_id: &[u8], epoch_id: u64) -> Result<Option<Vec<u8>>, Self::Error> {
|
||||||
|
self.get_epoch_data(group_id, epoch_id).map_err(From::from)
|
||||||
|
}
|
||||||
|
}
|
72
mlc-client/src/daos/groups_mapping.rs
Normal file
72
mlc-client/src/daos/groups_mapping.rs
Normal file
@ -0,0 +1,72 @@
|
|||||||
|
use std::sync::{Arc, Mutex};
|
||||||
|
|
||||||
|
use rusqlite::{Connection, OptionalExtension};
|
||||||
|
|
||||||
|
#[derive(Debug, Clone)]
|
||||||
|
pub struct GroupNamesStorage {
|
||||||
|
pub connection: Arc<Mutex<Connection>>,
|
||||||
|
}
|
||||||
|
|
||||||
|
impl GroupNamesStorage {
|
||||||
|
pub fn new(connection: Arc<Mutex<Connection>>) -> Self {
|
||||||
|
Self { connection }
|
||||||
|
}
|
||||||
|
|
||||||
|
pub fn set_name(&self, chat_id: &str, name: &str) -> anyhow::Result<()> {
|
||||||
|
let conn = self
|
||||||
|
.connection
|
||||||
|
.lock()
|
||||||
|
.expect("Connection mutext is poisoned. It's a critical error. Exiting.");
|
||||||
|
conn.execute(
|
||||||
|
"INSERT INTO mls_cli_group_names(
|
||||||
|
id,
|
||||||
|
name
|
||||||
|
) VALUES (
|
||||||
|
:id,
|
||||||
|
:name
|
||||||
|
) ON CONFLICT(id) DO UPDATE SET name=excluded.name",
|
||||||
|
rusqlite::named_params![
|
||||||
|
":id": chat_id,
|
||||||
|
":name": name
|
||||||
|
],
|
||||||
|
)?;
|
||||||
|
Ok(())
|
||||||
|
}
|
||||||
|
|
||||||
|
pub fn get_name(&self, group_id: &str) -> anyhow::Result<Option<String>> {
|
||||||
|
let conn = self
|
||||||
|
.connection
|
||||||
|
.lock()
|
||||||
|
.expect("Connection mutext is poisoned. It's a critical error. Exiting.");
|
||||||
|
conn.query_row(
|
||||||
|
"SELECT name
|
||||||
|
FROM mls_cli_group_names
|
||||||
|
WHERE id = :id",
|
||||||
|
rusqlite::named_params![
|
||||||
|
":id": group_id,
|
||||||
|
],
|
||||||
|
|row| row.get(0),
|
||||||
|
)
|
||||||
|
.optional()
|
||||||
|
.map_err(anyhow::Error::from)
|
||||||
|
}
|
||||||
|
|
||||||
|
pub fn get_id(&self, name: &str) -> anyhow::Result<Option<String>> {
|
||||||
|
let conn = self
|
||||||
|
.connection
|
||||||
|
.lock()
|
||||||
|
.expect("Connection mutext is poisoned. It's a critical error. Exiting.");
|
||||||
|
conn.query_row(
|
||||||
|
"SELECT
|
||||||
|
id
|
||||||
|
FROM mls_cli_group_names
|
||||||
|
WHERE name = :name",
|
||||||
|
rusqlite::named_params![
|
||||||
|
":name": name,
|
||||||
|
],
|
||||||
|
|row| row.get(0),
|
||||||
|
)
|
||||||
|
.optional()
|
||||||
|
.map_err(anyhow::Error::from)
|
||||||
|
}
|
||||||
|
}
|
118
mlc-client/src/daos/key_packages.rs
Normal file
118
mlc-client/src/daos/key_packages.rs
Normal file
@ -0,0 +1,118 @@
|
|||||||
|
use mls_rs_core::{
|
||||||
|
key_package::{KeyPackageData, KeyPackageStorage},
|
||||||
|
mls_rs_codec::{MlsDecode, MlsEncode},
|
||||||
|
time::MlsTime,
|
||||||
|
};
|
||||||
|
use rusqlite::{Connection, OptionalExtension, params};
|
||||||
|
use std::sync::{Arc, Mutex};
|
||||||
|
|
||||||
|
#[derive(Debug, Clone)]
|
||||||
|
/// SQLite storage for MLS Key Packages.
|
||||||
|
pub struct MlsCliKeyPackageStorage {
|
||||||
|
connection: Arc<Mutex<Connection>>,
|
||||||
|
}
|
||||||
|
|
||||||
|
impl MlsCliKeyPackageStorage {
|
||||||
|
pub fn new(connection: Arc<Mutex<Connection>>) -> MlsCliKeyPackageStorage {
|
||||||
|
MlsCliKeyPackageStorage { connection }
|
||||||
|
}
|
||||||
|
|
||||||
|
fn insert(&mut self, id: &[u8], key_package: KeyPackageData) -> anyhow::Result<()> {
|
||||||
|
let connection = self.connection.lock().unwrap();
|
||||||
|
|
||||||
|
connection
|
||||||
|
.execute(
|
||||||
|
"INSERT INTO key_package (id, expiration, data) VALUES (?,?,?)",
|
||||||
|
params![id, key_package.expiration, key_package.mls_encode_to_vec()?],
|
||||||
|
)
|
||||||
|
.map(|_| ())
|
||||||
|
.map_err(anyhow::Error::from)
|
||||||
|
}
|
||||||
|
|
||||||
|
fn get(&self, id: &[u8]) -> anyhow::Result<Option<KeyPackageData>> {
|
||||||
|
let connection = self.connection.lock().unwrap();
|
||||||
|
|
||||||
|
connection
|
||||||
|
.query_row(
|
||||||
|
"SELECT data FROM key_package WHERE id = ?",
|
||||||
|
params![id],
|
||||||
|
|row| {
|
||||||
|
Ok(
|
||||||
|
KeyPackageData::mls_decode(&mut row.get::<_, Vec<u8>>(0)?.as_slice())
|
||||||
|
.unwrap(),
|
||||||
|
)
|
||||||
|
},
|
||||||
|
)
|
||||||
|
.optional()
|
||||||
|
.map_err(anyhow::Error::from)
|
||||||
|
}
|
||||||
|
|
||||||
|
/// Delete a specific key package from storage based on it's id.
|
||||||
|
pub fn delete(&self, id: &[u8]) -> anyhow::Result<()> {
|
||||||
|
let connection = self.connection.lock().unwrap();
|
||||||
|
|
||||||
|
connection
|
||||||
|
.execute("DELETE FROM key_package where id = ?", params![id])
|
||||||
|
.map(|_| ())
|
||||||
|
.map_err(anyhow::Error::from)
|
||||||
|
}
|
||||||
|
|
||||||
|
/// Delete key packages that are expired based on the current system clock time.
|
||||||
|
pub fn delete_expired(&self) -> anyhow::Result<()> {
|
||||||
|
self.delete_expired_by_time(MlsTime::now().seconds_since_epoch())
|
||||||
|
}
|
||||||
|
|
||||||
|
/// Delete key packages that are expired based on an application provided time in seconds since
|
||||||
|
/// unix epoch.
|
||||||
|
pub fn delete_expired_by_time(&self, time: u64) -> anyhow::Result<()> {
|
||||||
|
let connection = self.connection.lock().unwrap();
|
||||||
|
|
||||||
|
connection
|
||||||
|
.execute(
|
||||||
|
"DELETE FROM key_package where expiration < ?",
|
||||||
|
params![time],
|
||||||
|
)
|
||||||
|
.map(|_| ())
|
||||||
|
.map_err(anyhow::Error::from)
|
||||||
|
}
|
||||||
|
|
||||||
|
/// Total number of key packages held in storage.
|
||||||
|
pub fn count(&self) -> anyhow::Result<usize> {
|
||||||
|
let connection = self.connection.lock().unwrap();
|
||||||
|
|
||||||
|
connection
|
||||||
|
.query_row("SELECT count(*) FROM key_package", params![], |row| {
|
||||||
|
row.get(0)
|
||||||
|
})
|
||||||
|
.map_err(anyhow::Error::from)
|
||||||
|
}
|
||||||
|
|
||||||
|
pub fn count_at_time(&self, time: u64) -> anyhow::Result<usize> {
|
||||||
|
self.delete_expired()?;
|
||||||
|
let connection = self.connection.lock().unwrap();
|
||||||
|
|
||||||
|
connection
|
||||||
|
.query_row(
|
||||||
|
"SELECT count(*) FROM key_package where expiration >= ?",
|
||||||
|
params![time],
|
||||||
|
|row| row.get(0),
|
||||||
|
)
|
||||||
|
.map_err(anyhow::Error::from)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
impl KeyPackageStorage for MlsCliKeyPackageStorage {
|
||||||
|
type Error = super::error::MlsCLIDBError;
|
||||||
|
|
||||||
|
fn insert(&mut self, id: Vec<u8>, pkg: KeyPackageData) -> Result<(), Self::Error> {
|
||||||
|
self.insert(id.as_slice(), pkg).map_err(From::from)
|
||||||
|
}
|
||||||
|
|
||||||
|
fn get(&self, id: &[u8]) -> Result<Option<KeyPackageData>, Self::Error> {
|
||||||
|
self.get(id).map_err(From::from)
|
||||||
|
}
|
||||||
|
|
||||||
|
fn delete(&mut self, id: &[u8]) -> Result<(), Self::Error> {
|
||||||
|
(*self).delete(id).map_err(From::from)
|
||||||
|
}
|
||||||
|
}
|
164
mlc-client/src/daos/messages.rs
Normal file
164
mlc-client/src/daos/messages.rs
Normal file
@ -0,0 +1,164 @@
|
|||||||
|
use std::sync::{Arc, Mutex};
|
||||||
|
|
||||||
|
use chrono::{DateTime, Utc};
|
||||||
|
use rusqlite::Connection;
|
||||||
|
|
||||||
|
#[derive(Debug, Clone)]
|
||||||
|
pub struct MessageStorage {
|
||||||
|
pub connection: Arc<Mutex<Connection>>,
|
||||||
|
}
|
||||||
|
|
||||||
|
pub struct MessageDTO {
|
||||||
|
pub group_id: Vec<u8>,
|
||||||
|
pub sender_id: Vec<u8>,
|
||||||
|
pub data: Vec<u8>,
|
||||||
|
pub created_at: DateTime<Utc>,
|
||||||
|
}
|
||||||
|
|
||||||
|
impl MessageStorage {
|
||||||
|
pub fn new(connection: Arc<Mutex<Connection>>) -> Self {
|
||||||
|
MessageStorage { connection }
|
||||||
|
}
|
||||||
|
|
||||||
|
pub fn insert_message(
|
||||||
|
&self,
|
||||||
|
chat_id: &[u8],
|
||||||
|
sender_id: &[u8],
|
||||||
|
message: &[u8],
|
||||||
|
epoch_id: u64,
|
||||||
|
) -> anyhow::Result<()> {
|
||||||
|
let conn = self
|
||||||
|
.connection
|
||||||
|
.lock()
|
||||||
|
.expect("Connection mutext is poisoned. It's a critical error. Exiting.");
|
||||||
|
conn.execute(
|
||||||
|
"INSERT INTO mls_cli_messages(
|
||||||
|
group_id,
|
||||||
|
sender_id,
|
||||||
|
data,
|
||||||
|
epoch_id
|
||||||
|
) VALUES (
|
||||||
|
:group_id,
|
||||||
|
:sender_id,
|
||||||
|
:data,
|
||||||
|
:epoch_id
|
||||||
|
)",
|
||||||
|
rusqlite::named_params![
|
||||||
|
":group_id": chat_id,
|
||||||
|
":sender_id": sender_id,
|
||||||
|
":data": message,
|
||||||
|
":epoch_id": epoch_id
|
||||||
|
],
|
||||||
|
)?;
|
||||||
|
Ok(())
|
||||||
|
}
|
||||||
|
|
||||||
|
#[allow(dead_code)]
|
||||||
|
pub fn get_messages(
|
||||||
|
&self,
|
||||||
|
group_id: &str,
|
||||||
|
offset: i64,
|
||||||
|
limit: i64,
|
||||||
|
) -> anyhow::Result<Vec<MessageDTO>> {
|
||||||
|
let conn = self
|
||||||
|
.connection
|
||||||
|
.lock()
|
||||||
|
.expect("Connection mutext is poisoned. It's a critical error. Exiting.");
|
||||||
|
let mut stmt = conn.prepare(
|
||||||
|
"SELECT
|
||||||
|
group_id,
|
||||||
|
sender_id,
|
||||||
|
data,
|
||||||
|
created_at
|
||||||
|
FROM mls_cli_messages
|
||||||
|
WHERE group_id = :group_id
|
||||||
|
ORDER BY created_at DESC
|
||||||
|
LIMIT :limit
|
||||||
|
OFFSET :offset",
|
||||||
|
)?;
|
||||||
|
let rows = stmt.query_map(
|
||||||
|
rusqlite::named_params![
|
||||||
|
":group_id": group_id.as_bytes(),
|
||||||
|
":offset": offset,
|
||||||
|
":limit": limit
|
||||||
|
],
|
||||||
|
|row| {
|
||||||
|
let group_id = row.get(0)?;
|
||||||
|
let sender_id = row.get(1)?;
|
||||||
|
let data = row.get(2)?;
|
||||||
|
let created_at = row.get(3)?;
|
||||||
|
Ok(MessageDTO {
|
||||||
|
group_id,
|
||||||
|
sender_id,
|
||||||
|
data,
|
||||||
|
created_at,
|
||||||
|
})
|
||||||
|
},
|
||||||
|
)?;
|
||||||
|
let mut res = vec![];
|
||||||
|
for row in rows {
|
||||||
|
res.push(row?);
|
||||||
|
}
|
||||||
|
Ok(res)
|
||||||
|
}
|
||||||
|
#[allow(dead_code)]
|
||||||
|
pub fn get_all_messages(&self, offset: i64, limit: i64) -> anyhow::Result<Vec<MessageDTO>> {
|
||||||
|
let conn = self
|
||||||
|
.connection
|
||||||
|
.lock()
|
||||||
|
.expect("Connection mutext is poisoned. It's a critical error. Exiting.");
|
||||||
|
let mut stmt = conn.prepare(
|
||||||
|
"SELECT
|
||||||
|
group_id,
|
||||||
|
sender_id,
|
||||||
|
data,
|
||||||
|
created_at
|
||||||
|
FROM mls_cli_messages
|
||||||
|
ORDER BY created_at DESC
|
||||||
|
LIMIT :limit
|
||||||
|
OFFSET :offset",
|
||||||
|
)?;
|
||||||
|
let rows = stmt.query_map(
|
||||||
|
rusqlite::named_params![":offset": offset, ":limit": limit],
|
||||||
|
|row| {
|
||||||
|
let group_id = row.get(0)?;
|
||||||
|
let sender_id = row.get(1)?;
|
||||||
|
let data = row.get(2)?;
|
||||||
|
let created_at = row.get(3)?;
|
||||||
|
Ok(MessageDTO {
|
||||||
|
group_id,
|
||||||
|
sender_id,
|
||||||
|
data,
|
||||||
|
created_at,
|
||||||
|
})
|
||||||
|
},
|
||||||
|
)?;
|
||||||
|
let mut res = vec![];
|
||||||
|
for row in rows {
|
||||||
|
res.push(row?);
|
||||||
|
}
|
||||||
|
Ok(res)
|
||||||
|
}
|
||||||
|
|
||||||
|
pub fn message_sent_in_this_epoch(&self, group_id: &str, epoch_id: u64) -> anyhow::Result<u64> {
|
||||||
|
let conn = self
|
||||||
|
.connection
|
||||||
|
.lock()
|
||||||
|
.expect("Connection mutext is poisoned. It's a critical error. Exiting.");
|
||||||
|
let mut stmt = conn.prepare(
|
||||||
|
"SELECT
|
||||||
|
COUNT(*)
|
||||||
|
FROM mls_cli_messages
|
||||||
|
WHERE group_id = :group_id
|
||||||
|
AND epoch_id = :epoch_id",
|
||||||
|
)?;
|
||||||
|
let count = stmt.query_row(
|
||||||
|
rusqlite::named_params![
|
||||||
|
":group_id": group_id.as_bytes(),
|
||||||
|
":epoch_id": epoch_id
|
||||||
|
],
|
||||||
|
|row| row.get(0),
|
||||||
|
)?;
|
||||||
|
Ok(count)
|
||||||
|
}
|
||||||
|
}
|
14
mlc-client/src/daos/mod.rs
Normal file
14
mlc-client/src/daos/mod.rs
Normal file
@ -0,0 +1,14 @@
|
|||||||
|
mod error;
|
||||||
|
mod group_state;
|
||||||
|
mod groups_mapping;
|
||||||
|
mod key_packages;
|
||||||
|
mod messages;
|
||||||
|
mod pre_shared_keys;
|
||||||
|
mod user;
|
||||||
|
|
||||||
|
pub use group_state::MlsCliGroupStateStorage;
|
||||||
|
pub use groups_mapping::GroupNamesStorage;
|
||||||
|
pub use key_packages::MlsCliKeyPackageStorage;
|
||||||
|
pub use messages::MessageStorage;
|
||||||
|
pub use pre_shared_keys::MlsCliPreSharedKeyStorage;
|
||||||
|
pub use user::UserdataStorage;
|
63
mlc-client/src/daos/pre_shared_keys.rs
Normal file
63
mlc-client/src/daos/pre_shared_keys.rs
Normal file
@ -0,0 +1,63 @@
|
|||||||
|
use mls_rs_core::psk::{ExternalPskId, PreSharedKey, PreSharedKeyStorage};
|
||||||
|
use rusqlite::{Connection, OptionalExtension, params};
|
||||||
|
use std::{
|
||||||
|
ops::Deref,
|
||||||
|
sync::{Arc, Mutex},
|
||||||
|
};
|
||||||
|
|
||||||
|
#[derive(Debug, Clone)]
|
||||||
|
/// SQLite storage for MLS pre-shared keys.
|
||||||
|
pub struct MlsCliPreSharedKeyStorage {
|
||||||
|
connection: Arc<Mutex<Connection>>,
|
||||||
|
}
|
||||||
|
|
||||||
|
impl MlsCliPreSharedKeyStorage {
|
||||||
|
pub fn new(connection: Arc<Mutex<Connection>>) -> Self {
|
||||||
|
MlsCliPreSharedKeyStorage { connection }
|
||||||
|
}
|
||||||
|
|
||||||
|
/// Insert a pre-shared key into storage.
|
||||||
|
pub fn insert(&self, psk_id: &[u8], psk: &PreSharedKey) -> anyhow::Result<()> {
|
||||||
|
let connection = self.connection.lock().unwrap();
|
||||||
|
|
||||||
|
// Upsert into the database
|
||||||
|
connection
|
||||||
|
.execute(
|
||||||
|
"INSERT INTO psk (psk_id, data) VALUES (?,?) ON CONFLICT(psk_id) DO UPDATE SET data=excluded.data",
|
||||||
|
params![psk_id, psk.deref()],
|
||||||
|
)
|
||||||
|
.map(|_| ()).map_err(anyhow::Error::from)
|
||||||
|
}
|
||||||
|
|
||||||
|
/// Get a pre-shared key from storage based on a unique id.
|
||||||
|
pub fn get(&self, psk_id: &[u8]) -> anyhow::Result<Option<PreSharedKey>> {
|
||||||
|
let connection = self.connection.lock().unwrap();
|
||||||
|
|
||||||
|
connection
|
||||||
|
.query_row(
|
||||||
|
"SELECT data FROM psk WHERE psk_id = ?",
|
||||||
|
params![psk_id],
|
||||||
|
|row| Ok(PreSharedKey::new(row.get(0)?)),
|
||||||
|
)
|
||||||
|
.optional()
|
||||||
|
.map_err(anyhow::Error::from)
|
||||||
|
}
|
||||||
|
|
||||||
|
/// Delete a pre-shared key from storage based on a unique id.
|
||||||
|
pub fn delete(&self, psk_id: &[u8]) -> anyhow::Result<()> {
|
||||||
|
let connection = self.connection.lock().unwrap();
|
||||||
|
|
||||||
|
connection
|
||||||
|
.execute("DELETE FROM psk WHERE psk_id = ?", params![psk_id])
|
||||||
|
.map(|_| ())
|
||||||
|
.map_err(anyhow::Error::from)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
impl PreSharedKeyStorage for MlsCliPreSharedKeyStorage {
|
||||||
|
type Error = super::error::MlsCLIDBError;
|
||||||
|
|
||||||
|
fn get(&self, id: &ExternalPskId) -> Result<Option<PreSharedKey>, Self::Error> {
|
||||||
|
self.get(id).map_err(From::from)
|
||||||
|
}
|
||||||
|
}
|
72
mlc-client/src/daos/user.rs
Normal file
72
mlc-client/src/daos/user.rs
Normal file
@ -0,0 +1,72 @@
|
|||||||
|
use std::sync::{Arc, Mutex};
|
||||||
|
|
||||||
|
use rusqlite::Connection;
|
||||||
|
use rusqlite::OptionalExtension;
|
||||||
|
|
||||||
|
#[derive(Debug, Clone)]
|
||||||
|
pub struct UserdataStorage {
|
||||||
|
pub connection: Arc<Mutex<Connection>>,
|
||||||
|
}
|
||||||
|
|
||||||
|
pub struct Userdata {
|
||||||
|
pub user_id: String,
|
||||||
|
pub secret_key: Vec<u8>,
|
||||||
|
pub public_key: Vec<u8>,
|
||||||
|
}
|
||||||
|
|
||||||
|
impl UserdataStorage {
|
||||||
|
pub fn new(connection: Arc<Mutex<Connection>>) -> Self {
|
||||||
|
Self { connection }
|
||||||
|
}
|
||||||
|
|
||||||
|
pub fn get_user(&self) -> anyhow::Result<Option<Userdata>> {
|
||||||
|
let conn = self
|
||||||
|
.connection
|
||||||
|
.lock()
|
||||||
|
.expect("Connection mutext is poisoned. It's a critical error. Exiting.");
|
||||||
|
let res = conn
|
||||||
|
.query_row(
|
||||||
|
"SELECT user_id, secret_key, public_key FROM mls_cli_userdata",
|
||||||
|
[],
|
||||||
|
|row| {
|
||||||
|
let user_id = row.get(0)?;
|
||||||
|
let secret_key = row.get(1)?;
|
||||||
|
let public_key = row.get(2)?;
|
||||||
|
Ok(Userdata {
|
||||||
|
user_id,
|
||||||
|
secret_key,
|
||||||
|
public_key,
|
||||||
|
})
|
||||||
|
},
|
||||||
|
)
|
||||||
|
.optional()?;
|
||||||
|
|
||||||
|
Ok(res)
|
||||||
|
}
|
||||||
|
|
||||||
|
pub fn truncate(&self) -> anyhow::Result<()> {
|
||||||
|
let conn = self
|
||||||
|
.connection
|
||||||
|
.lock()
|
||||||
|
.expect("Connection mutext is poisoned. It's a critical error. Exiting.");
|
||||||
|
conn.execute("DELETE FROM mls_cli_userdata", [])?;
|
||||||
|
Ok(())
|
||||||
|
}
|
||||||
|
|
||||||
|
pub fn save_user(
|
||||||
|
&self,
|
||||||
|
user_id: &str,
|
||||||
|
public_key: &[u8],
|
||||||
|
secret_key: &[u8],
|
||||||
|
) -> anyhow::Result<()> {
|
||||||
|
let conn = self
|
||||||
|
.connection
|
||||||
|
.lock()
|
||||||
|
.expect("Connection mutext is poisoned. It's a critical error. Exiting.");
|
||||||
|
conn.execute(
|
||||||
|
"INSERT INTO mls_cli_userdata (user_id, secret_key, public_key) VALUES (?1, ?2, ?3)",
|
||||||
|
rusqlite::params![user_id, secret_key, public_key],
|
||||||
|
)?;
|
||||||
|
Ok(())
|
||||||
|
}
|
||||||
|
}
|
158
mlc-client/src/listener.rs
Normal file
158
mlc-client/src/listener.rs
Normal file
@ -0,0 +1,158 @@
|
|||||||
|
use std::collections::HashMap;
|
||||||
|
use std::sync::Arc;
|
||||||
|
|
||||||
|
use crate::context::ClientContext;
|
||||||
|
use mls_rs::MlsMessage;
|
||||||
|
use mls_rs::group::ReceivedMessage;
|
||||||
|
use mls_rs::time::MlsTime;
|
||||||
|
use tokio::sync::RwLock;
|
||||||
|
|
||||||
|
pub fn process_updates(ctx: Arc<RwLock<ClientContext>>) {
|
||||||
|
// Messages not related to any group
|
||||||
|
let mut general_messages: Vec<MlsMessage> = vec![];
|
||||||
|
// Messages related to some groups
|
||||||
|
let mut grouped_updates: HashMap<Vec<u8>, Vec<MlsMessage>> = HashMap::default();
|
||||||
|
let ctx = ctx.read();
|
||||||
|
let Some(mls_ctx) = &ctx.mls else { return };
|
||||||
|
let updates = ctx.ds_client.get_updates().await;
|
||||||
|
|
||||||
|
match updates {
|
||||||
|
Ok(updates) => {
|
||||||
|
for update in updates {
|
||||||
|
if let Some(group_id) = update.group_id().map(|id| id.to_vec()) {
|
||||||
|
grouped_updates.entry(group_id).or_default().push(update);
|
||||||
|
continue;
|
||||||
|
} else {
|
||||||
|
general_messages.push(update);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
Err(err) => {
|
||||||
|
maybe_printer_print(
|
||||||
|
format!("Cannot receive updates! Reason: {}", err.to_string().red()),
|
||||||
|
printer,
|
||||||
|
);
|
||||||
|
}
|
||||||
|
};
|
||||||
|
|
||||||
|
if !general_messages.is_empty() {
|
||||||
|
for msg in &general_messages {
|
||||||
|
if msg.wire_format() == mls_rs::WireFormat::Welcome {
|
||||||
|
match mls_ctx.mls_client.join_group(None, msg) {
|
||||||
|
Ok((mut group, _)) => {
|
||||||
|
group.write_to_storage().ok();
|
||||||
|
}
|
||||||
|
Err(err) => {
|
||||||
|
maybe_printer_print(
|
||||||
|
format!("Cannot join group! Reason: {}", err.to_string().red()),
|
||||||
|
printer,
|
||||||
|
);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
if !grouped_updates.is_empty() {
|
||||||
|
for (group_id, messages) in &grouped_updates {
|
||||||
|
let Ok(mut group) = mls_ctx.mls_client.load_group(group_id) else {
|
||||||
|
maybe_printer_print(
|
||||||
|
format!(
|
||||||
|
"Cannot load group with id: {:?}",
|
||||||
|
String::from_utf8_lossy(group_id).red().to_string()
|
||||||
|
),
|
||||||
|
printer,
|
||||||
|
);
|
||||||
|
continue;
|
||||||
|
};
|
||||||
|
|
||||||
|
for message in messages {
|
||||||
|
match group.process_incoming_message_with_time(message.clone(), MlsTime::now()) {
|
||||||
|
Ok(msg) => match msg {
|
||||||
|
ReceivedMessage::ApplicationMessage(app_msg) => {
|
||||||
|
let Some(member) = group.member_at_index(app_msg.sender_index) else {
|
||||||
|
maybe_printer_print(
|
||||||
|
format!(
|
||||||
|
"Cannot find member with index: {}",
|
||||||
|
app_msg.sender_index
|
||||||
|
),
|
||||||
|
printer,
|
||||||
|
);
|
||||||
|
continue;
|
||||||
|
};
|
||||||
|
let sender_name = member
|
||||||
|
.signing_identity
|
||||||
|
.credential
|
||||||
|
.as_basic()
|
||||||
|
.unwrap()
|
||||||
|
.identifier
|
||||||
|
.as_slice();
|
||||||
|
let content = String::from_utf8_lossy(app_msg.data());
|
||||||
|
let group_name = String::from_utf8_lossy(group_id);
|
||||||
|
ctx.message_storage
|
||||||
|
.insert_message(
|
||||||
|
group_id,
|
||||||
|
sender_name,
|
||||||
|
app_msg.data(),
|
||||||
|
group.context().epoch(),
|
||||||
|
)
|
||||||
|
.ok();
|
||||||
|
eprintln!(
|
||||||
|
"Received a message. group: `{}`; content: `{}`",
|
||||||
|
group_name.yellow(),
|
||||||
|
content.green(),
|
||||||
|
);
|
||||||
|
}
|
||||||
|
ReceivedMessage::Proposal(proposal) => {
|
||||||
|
let commit = group
|
||||||
|
.commit_builder()
|
||||||
|
.raw_proposal(proposal.proposal)
|
||||||
|
.build();
|
||||||
|
match commit {
|
||||||
|
Ok(cmt) => {
|
||||||
|
group.apply_pending_commit().ok();
|
||||||
|
ctx.ds_client.post_commit(&cmt).await.ok();
|
||||||
|
}
|
||||||
|
Err(err) => {
|
||||||
|
maybe_printer_print(
|
||||||
|
format!(
|
||||||
|
"Cannot create commit! Reason: {}",
|
||||||
|
err.to_string().red()
|
||||||
|
),
|
||||||
|
printer,
|
||||||
|
);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
_ => {}
|
||||||
|
},
|
||||||
|
Err(err) => match err {
|
||||||
|
// Ignore messages from self
|
||||||
|
mls_rs::error::MlsError::CantProcessMessageFromSelf => {}
|
||||||
|
_ => {
|
||||||
|
maybe_printer_print(
|
||||||
|
format!(
|
||||||
|
"Cannot process message! Reason: {}",
|
||||||
|
err.to_string().red()
|
||||||
|
),
|
||||||
|
printer,
|
||||||
|
);
|
||||||
|
}
|
||||||
|
},
|
||||||
|
}
|
||||||
|
}
|
||||||
|
group.write_to_storage().ok();
|
||||||
|
}
|
||||||
|
}
|
||||||
|
Vec::clear(&mut general_messages);
|
||||||
|
HashMap::clear(&mut grouped_updates);
|
||||||
|
}
|
||||||
|
|
||||||
|
pub async fn listen_to_updates(
|
||||||
|
ctx: Arc<RwLock<ClientContext>>,
|
||||||
|
printer: ExternalPrinter<String>,
|
||||||
|
) -> ! {
|
||||||
|
loop {
|
||||||
|
tokio::time::sleep(std::time::Duration::from_secs(1)).await;
|
||||||
|
process_updates(ctx.clone(), Some(&printer)).await;
|
||||||
|
}
|
||||||
|
}
|
129
mlc-client/src/main.rs
Normal file
129
mlc-client/src/main.rs
Normal file
@ -0,0 +1,129 @@
|
|||||||
|
#![cfg_attr(not(debug_assertions), windows_subsystem = "windows")] // hide console window on Windows in release
|
||||||
|
#![allow(rustdoc::missing_crate_level_docs)] // it's an example
|
||||||
|
|
||||||
|
mod ctx;
|
||||||
|
mod daos;
|
||||||
|
mod migrations;
|
||||||
|
mod screens;
|
||||||
|
|
||||||
|
use std::sync::Arc;
|
||||||
|
use std::sync::Mutex;
|
||||||
|
|
||||||
|
use ctx::ClientCTX;
|
||||||
|
use eframe::egui;
|
||||||
|
use mls_rs::client_builder::BaseConfig;
|
||||||
|
use mls_rs::client_builder::WithCryptoProvider;
|
||||||
|
use mls_rs::client_builder::WithGroupStateStorage;
|
||||||
|
use mls_rs::client_builder::WithIdentityProvider;
|
||||||
|
use mls_rs::client_builder::WithKeyPackageRepo;
|
||||||
|
use mls_rs::client_builder::WithPskStore;
|
||||||
|
use mls_rs::identity::basic::BasicIdentityProvider;
|
||||||
|
use mls_rs_crypto_openssl::OpensslCryptoProvider;
|
||||||
|
use screens::Screens;
|
||||||
|
use tracing::level_filters::LevelFilter;
|
||||||
|
use tracing_subscriber::EnvFilter;
|
||||||
|
use tracing_subscriber::FmtSubscriber;
|
||||||
|
|
||||||
|
pub type CliMlsConfig = WithIdentityProvider<
|
||||||
|
BasicIdentityProvider,
|
||||||
|
WithGroupStateStorage<
|
||||||
|
daos::MlsCliGroupStateStorage,
|
||||||
|
WithPskStore<
|
||||||
|
daos::MlsCliPreSharedKeyStorage,
|
||||||
|
WithKeyPackageRepo<
|
||||||
|
daos::MlsCliKeyPackageStorage,
|
||||||
|
WithCryptoProvider<OpensslCryptoProvider, BaseConfig>,
|
||||||
|
>,
|
||||||
|
>,
|
||||||
|
>,
|
||||||
|
>;
|
||||||
|
|
||||||
|
pub type MlsClient = mls_rs::Client<CliMlsConfig>;
|
||||||
|
pub type MlsGroup = mls_rs::Group<CliMlsConfig>;
|
||||||
|
|
||||||
|
fn main() -> eframe::Result {
|
||||||
|
// env_logger::init(); // Log to stderr (if you run with `RUST_LOG=debug`).
|
||||||
|
FmtSubscriber::builder()
|
||||||
|
.with_env_filter(
|
||||||
|
EnvFilter::builder()
|
||||||
|
.with_default_directive(LevelFilter::INFO.into())
|
||||||
|
.from_env_lossy(),
|
||||||
|
)
|
||||||
|
.with_writer(std::io::stderr)
|
||||||
|
.init();
|
||||||
|
|
||||||
|
let options = eframe::NativeOptions {
|
||||||
|
viewport: egui::ViewportBuilder::default().with_title("My little chat"),
|
||||||
|
centered: true,
|
||||||
|
vsync: true,
|
||||||
|
..Default::default()
|
||||||
|
};
|
||||||
|
|
||||||
|
let machine_id = machine_uid::get().expect("Cannot determine machine uid");
|
||||||
|
let project_dir = directories::ProjectDirs::from("com", "le-memese", "mlc")
|
||||||
|
.expect("Cannot find local project directory");
|
||||||
|
let data_dir = project_dir.data_local_dir();
|
||||||
|
|
||||||
|
tracing::info!("Data dir: {}", data_dir.display());
|
||||||
|
tracing::info!("Machine id: {}", machine_id);
|
||||||
|
|
||||||
|
std::fs::create_dir_all(&data_dir).expect("Cannot create app data dir");
|
||||||
|
|
||||||
|
let mut sqlite_conn =
|
||||||
|
rusqlite::Connection::open(data_dir.join("data.sqlite3")).expect(&format!(
|
||||||
|
"Can't open database file at {}",
|
||||||
|
data_dir.join("data").display()
|
||||||
|
));
|
||||||
|
|
||||||
|
{
|
||||||
|
sqlite_conn
|
||||||
|
.pragma_update_and_check(None, "journal_mode", "WAL", |_| Ok(()))
|
||||||
|
.expect("Cannot set WAL mode on the database");
|
||||||
|
migrations::MIGRATIONS
|
||||||
|
.to_latest(&mut sqlite_conn)
|
||||||
|
.expect("Cannot run migrations");
|
||||||
|
}
|
||||||
|
|
||||||
|
let mut ctx = ClientCTX::new(Arc::new(Mutex::new(sqlite_conn)));
|
||||||
|
ctx.machine_id = machine_id;
|
||||||
|
|
||||||
|
let app = MyApp::new(ctx);
|
||||||
|
|
||||||
|
eframe::run_native(
|
||||||
|
"My egui App",
|
||||||
|
options,
|
||||||
|
Box::new(|cc| {
|
||||||
|
// This gives us image support:
|
||||||
|
egui_extras::install_image_loaders(&cc.egui_ctx);
|
||||||
|
Ok(Box::new(app))
|
||||||
|
}),
|
||||||
|
)
|
||||||
|
}
|
||||||
|
|
||||||
|
struct MyApp {
|
||||||
|
login: screens::LoginScreen,
|
||||||
|
chats: screens::ChatsScreen,
|
||||||
|
register: screens::RegisterScreen,
|
||||||
|
ctx: ctx::ClientCTX,
|
||||||
|
}
|
||||||
|
|
||||||
|
impl MyApp {
|
||||||
|
fn new(ctx: ClientCTX) -> Self {
|
||||||
|
Self {
|
||||||
|
login: screens::LoginScreen::default(),
|
||||||
|
chats: screens::ChatsScreen::default(),
|
||||||
|
register: screens::RegisterScreen::default(),
|
||||||
|
ctx,
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
impl eframe::App for MyApp {
|
||||||
|
fn update(&mut self, ctx: &egui::Context, _frame: &mut eframe::Frame) {
|
||||||
|
egui::CentralPanel::default().show(ctx, |ui| match &self.ctx.current_screen {
|
||||||
|
Screens::Login => self.login.update(&mut self.ctx, ui),
|
||||||
|
Screens::Chats => self.chats.update(ctx, &mut self.ctx, ui),
|
||||||
|
Screens::Register => self.register.update(&mut self.ctx, ui),
|
||||||
|
});
|
||||||
|
}
|
||||||
|
}
|
83
mlc-client/src/migrations.rs
Normal file
83
mlc-client/src/migrations.rs
Normal file
@ -0,0 +1,83 @@
|
|||||||
|
use std::sync::LazyLock;
|
||||||
|
|
||||||
|
use rusqlite_migration::{M, Migrations};
|
||||||
|
|
||||||
|
pub static MIGRATIONS: LazyLock<Migrations> = LazyLock::new(|| {
|
||||||
|
Migrations::new(vec![
|
||||||
|
// Initial migration
|
||||||
|
M::up(
|
||||||
|
r#"
|
||||||
|
CREATE TABLE mls_group (
|
||||||
|
group_id BLOB PRIMARY KEY,
|
||||||
|
snapshot BLOB NOT NULL
|
||||||
|
) WITHOUT ROWID;
|
||||||
|
|
||||||
|
CREATE TABLE epoch (
|
||||||
|
group_id BLOB,
|
||||||
|
epoch_id INTEGER,
|
||||||
|
epoch_data BLOB NOT NULL,
|
||||||
|
FOREIGN KEY (group_id) REFERENCES mls_group (group_id) ON DELETE CASCADE
|
||||||
|
PRIMARY KEY (group_id, epoch_id)
|
||||||
|
) WITHOUT ROWID;
|
||||||
|
|
||||||
|
CREATE TABLE key_package (
|
||||||
|
id BLOB PRIMARY KEY,
|
||||||
|
expiration INTEGER,
|
||||||
|
data BLOB NOT NULL
|
||||||
|
) WITHOUT ROWID;
|
||||||
|
CREATE INDEX key_package_exp ON key_package (expiration);
|
||||||
|
|
||||||
|
CREATE TABLE psk (
|
||||||
|
psk_id BLOB PRIMARY KEY,
|
||||||
|
data BLOB NOT NULL
|
||||||
|
) WITHOUT ROWID;
|
||||||
|
|
||||||
|
CREATE TABLE IF NOT EXISTS mls_cli_messages (
|
||||||
|
id INTEGER PRIMARY KEY AUTOINCREMENT,
|
||||||
|
group_id BLOB NOT NULL,
|
||||||
|
sender_id INT NOT NULL,
|
||||||
|
data BLOB NOT NULL,
|
||||||
|
epoch_id UNSIGNED BIG INT NOT NULL,
|
||||||
|
created_at TIMESTAMP NOT NULL DEFAULT CURRENT_TIMESTAMP,
|
||||||
|
FOREIGN KEY (group_id) REFERENCES mls_group (group_id) ON DELETE CASCADE
|
||||||
|
);
|
||||||
|
CREATE INDEX IF NOT EXISTS idx_message_group_id ON mls_cli_messages(group_id);
|
||||||
|
|
||||||
|
CREATE TABLE IF NOT EXISTS mls_cli_userdata (
|
||||||
|
user_id TEXT PRIMARY KEY,
|
||||||
|
secret_key BLOB NOT NULL,
|
||||||
|
public_key BLOB NOT NULL
|
||||||
|
) WITHOUT ROWID;
|
||||||
|
|
||||||
|
CREATE TABLE IF NOT EXISTS mls_cli_group_names (
|
||||||
|
id TEXT NOT NULL PRIMARY KEY,
|
||||||
|
name TEXT NOT NULL
|
||||||
|
) WITHOUT ROWID;
|
||||||
|
|
||||||
|
"#,
|
||||||
|
)
|
||||||
|
.down(
|
||||||
|
r#"
|
||||||
|
DROP TABLE mls_cli_messages;
|
||||||
|
DROP TABLE mls_cli_userdata;
|
||||||
|
DROP TABLE mls_group;
|
||||||
|
DROP TABLE epoch;
|
||||||
|
DROP TABLE key_package;
|
||||||
|
DROP TABLE psk;
|
||||||
|
DROP TABLE mls_cli_group_names;
|
||||||
|
"#,
|
||||||
|
),
|
||||||
|
])
|
||||||
|
});
|
||||||
|
|
||||||
|
#[cfg(test)]
|
||||||
|
mod tests {
|
||||||
|
use super::*;
|
||||||
|
|
||||||
|
// Validating that migrations are correctly defined. It is enough to test in the sync context,
|
||||||
|
// because under the hood, tokio_rusqlite executes the migrations in a sync context anyway.
|
||||||
|
#[test]
|
||||||
|
fn migrations_test() {
|
||||||
|
assert!(MIGRATIONS.validate().is_ok());
|
||||||
|
}
|
||||||
|
}
|
25
mlc-client/src/screens/chats.rs
Normal file
25
mlc-client/src/screens/chats.rs
Normal file
@ -0,0 +1,25 @@
|
|||||||
|
use egui::{ComboBox, Label, SidePanel, Ui, Widget};
|
||||||
|
|
||||||
|
use crate::ctx::ClientCTX;
|
||||||
|
|
||||||
|
#[derive(Default)]
|
||||||
|
pub struct ChatsScreen {
|
||||||
|
chats: Vec<String>,
|
||||||
|
selected_chat: Option<String>,
|
||||||
|
}
|
||||||
|
|
||||||
|
impl ChatsScreen {
|
||||||
|
pub fn update(&mut self, egui_ctx: &egui::Context, ctx: &mut ClientCTX, ui: &mut Ui) {
|
||||||
|
SidePanel::new(egui::panel::Side::Left, "chats_list")
|
||||||
|
.resizable(false)
|
||||||
|
.show_inside(ui, |ui| {
|
||||||
|
for chat in &self.chats {
|
||||||
|
let label = Label::new(format!("{}", chat)).ui(ui);
|
||||||
|
if label.clicked() {
|
||||||
|
self.selected_chat = Some(format!("Chat {}", chat));
|
||||||
|
}
|
||||||
|
}
|
||||||
|
});
|
||||||
|
ui.label(&format!("Chat : {:?}", self.selected_chat));
|
||||||
|
}
|
||||||
|
}
|
45
mlc-client/src/screens/login.rs
Normal file
45
mlc-client/src/screens/login.rs
Normal file
@ -0,0 +1,45 @@
|
|||||||
|
use egui::Color32;
|
||||||
|
|
||||||
|
use crate::{Screens, ctx::ClientCTX};
|
||||||
|
|
||||||
|
#[derive(Default)]
|
||||||
|
pub struct LoginScreen {
|
||||||
|
email: String,
|
||||||
|
password: String,
|
||||||
|
highlight_password: bool,
|
||||||
|
highlight_email: bool,
|
||||||
|
}
|
||||||
|
|
||||||
|
impl LoginScreen {
|
||||||
|
pub fn update(&mut self, ctx: &mut ClientCTX, ui: &mut egui::Ui) {
|
||||||
|
ui.vertical_centered(|ui| {
|
||||||
|
let mut email_field = egui::TextEdit::singleline(&mut self.email).hint_text("Email");
|
||||||
|
let mut password_field = egui::TextEdit::singleline(&mut self.password)
|
||||||
|
.password(true)
|
||||||
|
.hint_text("Password");
|
||||||
|
if self.highlight_email {
|
||||||
|
email_field = email_field.background_color(Color32::RED);
|
||||||
|
}
|
||||||
|
if self.highlight_password {
|
||||||
|
password_field = password_field.background_color(Color32::RED);
|
||||||
|
}
|
||||||
|
ui.add(email_field);
|
||||||
|
ui.add(password_field);
|
||||||
|
if ui.button("Login").clicked() || ui.input(|inp| inp.key_pressed(egui::Key::Enter)) {
|
||||||
|
self.highlight_email = self.email.is_empty();
|
||||||
|
self.highlight_password = self.password.is_empty();
|
||||||
|
if self.highlight_password || self.highlight_email {
|
||||||
|
return;
|
||||||
|
}
|
||||||
|
ctx.password = self.password.clone();
|
||||||
|
ctx.email = self.email.clone();
|
||||||
|
self.password.clear();
|
||||||
|
ctx.current_screen = Screens::Chats;
|
||||||
|
}
|
||||||
|
if ui.button("Register").clicked() {
|
||||||
|
*self = Self::default();
|
||||||
|
ctx.current_screen = Screens::Register;
|
||||||
|
}
|
||||||
|
});
|
||||||
|
}
|
||||||
|
}
|
15
mlc-client/src/screens/mod.rs
Normal file
15
mlc-client/src/screens/mod.rs
Normal file
@ -0,0 +1,15 @@
|
|||||||
|
mod chats;
|
||||||
|
mod login;
|
||||||
|
mod register;
|
||||||
|
|
||||||
|
pub use chats::ChatsScreen;
|
||||||
|
pub use login::LoginScreen;
|
||||||
|
pub use register::RegisterScreen;
|
||||||
|
|
||||||
|
#[derive(Default)]
|
||||||
|
pub enum Screens {
|
||||||
|
#[default]
|
||||||
|
Login,
|
||||||
|
Register,
|
||||||
|
Chats,
|
||||||
|
}
|
61
mlc-client/src/screens/register.rs
Normal file
61
mlc-client/src/screens/register.rs
Normal file
@ -0,0 +1,61 @@
|
|||||||
|
use egui::Color32;
|
||||||
|
|
||||||
|
use crate::{ctx::ClientCTX, screens::Screens};
|
||||||
|
|
||||||
|
#[derive(Default)]
|
||||||
|
pub struct RegisterScreen {
|
||||||
|
email: String,
|
||||||
|
password: String,
|
||||||
|
repeat_password: String,
|
||||||
|
highlight_password: bool,
|
||||||
|
highlight_repeat_password: bool,
|
||||||
|
highlight_email: bool,
|
||||||
|
}
|
||||||
|
|
||||||
|
impl RegisterScreen {
|
||||||
|
pub fn update(&mut self, ctx: &mut ClientCTX, ui: &mut egui::Ui) {
|
||||||
|
ui.vertical_centered(|ui| {
|
||||||
|
let mut email_field = egui::TextEdit::singleline(&mut self.email).hint_text("Email");
|
||||||
|
let mut password_field = egui::TextEdit::singleline(&mut self.password)
|
||||||
|
.password(true)
|
||||||
|
.hint_text("Password");
|
||||||
|
let mut repeat_password_field = egui::TextEdit::singleline(&mut self.repeat_password)
|
||||||
|
.password(true)
|
||||||
|
.hint_text("Repeat password");
|
||||||
|
if self.highlight_email {
|
||||||
|
email_field = email_field.background_color(Color32::RED);
|
||||||
|
}
|
||||||
|
if self.highlight_password {
|
||||||
|
password_field = password_field.background_color(Color32::RED);
|
||||||
|
}
|
||||||
|
if self.highlight_repeat_password {
|
||||||
|
repeat_password_field = repeat_password_field.background_color(Color32::RED);
|
||||||
|
}
|
||||||
|
ui.add(email_field);
|
||||||
|
ui.add(password_field);
|
||||||
|
ui.add(repeat_password_field);
|
||||||
|
if ui.button("Register").clicked() || ui.input(|inp| inp.key_pressed(egui::Key::Enter))
|
||||||
|
{
|
||||||
|
self.highlight_email = self.email.is_empty();
|
||||||
|
self.highlight_password = self.password.is_empty();
|
||||||
|
if self.password != self.repeat_password {
|
||||||
|
self.highlight_repeat_password = true;
|
||||||
|
self.highlight_password = true;
|
||||||
|
return;
|
||||||
|
}
|
||||||
|
if self.highlight_password || self.highlight_email || self.highlight_repeat_password
|
||||||
|
{
|
||||||
|
return;
|
||||||
|
}
|
||||||
|
ctx.password = self.password.clone();
|
||||||
|
ctx.email = self.email.clone();
|
||||||
|
self.password.clear();
|
||||||
|
ctx.current_screen = Screens::Login;
|
||||||
|
}
|
||||||
|
if ui.button("Back").clicked() {
|
||||||
|
*self = Self::default();
|
||||||
|
ctx.current_screen = Screens::Login;
|
||||||
|
}
|
||||||
|
});
|
||||||
|
}
|
||||||
|
}
|
25
mlc-server/Cargo.toml
Normal file
25
mlc-server/Cargo.toml
Normal file
@ -0,0 +1,25 @@
|
|||||||
|
[package]
|
||||||
|
name = "mlc-server"
|
||||||
|
version = "0.1.0"
|
||||||
|
edition = "2024"
|
||||||
|
|
||||||
|
[dependencies]
|
||||||
|
mls-rs.workspace = true
|
||||||
|
mls-rs-core.workspace = true
|
||||||
|
mls-rs-crypto-openssl.workspace = true
|
||||||
|
anyhow.workspace = true
|
||||||
|
reqwest.workspace = true
|
||||||
|
thiserror.workspace = true
|
||||||
|
serde_json.workspace = true
|
||||||
|
tracing.workspace = true
|
||||||
|
tracing-subscriber.workspace = true
|
||||||
|
|
||||||
|
clap = { version = "^4", features = ["derive", "env"] }
|
||||||
|
tokio = { version = "1.44.1", features = ["full"] }
|
||||||
|
sqlx = { version = "0.8.3", features = ["sqlite", "runtime-tokio"] }
|
||||||
|
dotenv = "0.15.0"
|
||||||
|
axum = { version = "0.8.1", features = ["ws"] }
|
||||||
|
rustc-hash = { version = "2.1.1", features = ["rand"] }
|
||||||
|
uuid = { version = "1.16.0", features = ["v4"] }
|
||||||
|
futures = "0.3.31"
|
||||||
|
|
0
mlc-server/Dockerfile
Normal file
0
mlc-server/Dockerfile
Normal file
5
mlc-server/build.rs
Normal file
5
mlc-server/build.rs
Normal file
@ -0,0 +1,5 @@
|
|||||||
|
// generated by `sqlx migrate build-script`
|
||||||
|
fn main() {
|
||||||
|
// trigger recompilation when a new migration is added
|
||||||
|
println!("cargo:rerun-if-changed=migrations");
|
||||||
|
}
|
6
mlc-server/migrations/20250323085527_init.down.sql
Normal file
6
mlc-server/migrations/20250323085527_init.down.sql
Normal file
@ -0,0 +1,6 @@
|
|||||||
|
DROP TABLE `mlc_keypackage`;
|
||||||
|
DROP TABLE `mlc_group_participant`;
|
||||||
|
DROP TABLE `mlc_group`;
|
||||||
|
DROP TABLE `mlc_device`;
|
||||||
|
DROP TABLE `mlc_user`;
|
||||||
|
DROP TABLE `mlc_inbox`;
|
40
mlc-server/migrations/20250323085527_init.up.sql
Normal file
40
mlc-server/migrations/20250323085527_init.up.sql
Normal file
@ -0,0 +1,40 @@
|
|||||||
|
CREATE TABLE mlc_user(
|
||||||
|
username TEXT PRIMARY KEY,
|
||||||
|
password TEXT,
|
||||||
|
email TEXT
|
||||||
|
) WITHOUT ROWID;
|
||||||
|
|
||||||
|
CREATE TABLE mlc_device(
|
||||||
|
id TEXT,
|
||||||
|
username TEXT,
|
||||||
|
PRIMARY KEY (id, username),
|
||||||
|
FOREIGN KEY (username) REFERENCES mlc_user(username)
|
||||||
|
) WITHOUT ROWID;
|
||||||
|
|
||||||
|
CREATE TABLE mlc_group(id TEXT PRIMARY KEY) WITHOUT ROWID;
|
||||||
|
|
||||||
|
CREATE TABLE mlc_group_participant(
|
||||||
|
id TEXT PRIMARY KEY,
|
||||||
|
group_id BIGINT,
|
||||||
|
device_id TEXT,
|
||||||
|
username TEXT,
|
||||||
|
FOREIGN KEY (group_id) REFERENCES mlc_group(id),
|
||||||
|
FOREIGN KEY (device_id, username) REFERENCES mlc_device(id, username)
|
||||||
|
) WITHOUT ROWID;
|
||||||
|
|
||||||
|
CREATE TABLE mlc_keypackage(
|
||||||
|
hash_ref BLOB PRIMARY KEY,
|
||||||
|
device_id TEXT,
|
||||||
|
username TEXT,
|
||||||
|
expiration BIGINT,
|
||||||
|
given BOOLEAN,
|
||||||
|
data BLOB,
|
||||||
|
FOREIGN KEY (device_id, username) REFERENCES mlc_device(id, username)
|
||||||
|
) WITHOUT ROWID;
|
||||||
|
|
||||||
|
CREATE TABLE mlc_inbox(
|
||||||
|
id TEXT PRIMARY KEY,
|
||||||
|
device_id TEXT,
|
||||||
|
username TEXT,
|
||||||
|
FOREIGN KEY (device_id, username) REFERENCES mlc_device(id, username)
|
||||||
|
) WITHOUT ROWID;
|
14
mlc-server/src/args.rs
Normal file
14
mlc-server/src/args.rs
Normal file
@ -0,0 +1,14 @@
|
|||||||
|
#[derive(Debug, Clone, clap::Parser)]
|
||||||
|
pub struct MlcConfig {
|
||||||
|
#[arg(long = "port", default_value = "8080", env = "MLC_SERVER_PORT")]
|
||||||
|
pub server_port: u16,
|
||||||
|
#[arg(long = "host", default_value = "0.0.0.0", env = "MLC_SERVER_HOST")]
|
||||||
|
pub server_host: String,
|
||||||
|
#[arg(long, short, default_value = "mlc.sqlite3", env = "MLC_DB_PATH")]
|
||||||
|
pub db_path: String,
|
||||||
|
#[arg(long, short, default_value = "secret", env = "MLC_JWT_SECRET")]
|
||||||
|
pub jwt_secret: String,
|
||||||
|
|
||||||
|
#[arg(long, short, default_value = "100", env = "MLC_PER_USER_MSG_BUFFER")]
|
||||||
|
pub per_user_msg_buffer: usize,
|
||||||
|
}
|
0
mlc-server/src/daos/mod.rs
Normal file
0
mlc-server/src/daos/mod.rs
Normal file
0
mlc-server/src/jwt.rs
Normal file
0
mlc-server/src/jwt.rs
Normal file
163
mlc-server/src/main.rs
Normal file
163
mlc-server/src/main.rs
Normal file
@ -0,0 +1,163 @@
|
|||||||
|
use std::{fs::OpenOptions, path::Path};
|
||||||
|
|
||||||
|
use axum::{
|
||||||
|
Extension,
|
||||||
|
extract::{State, ws::Message},
|
||||||
|
http::HeaderMap,
|
||||||
|
response::IntoResponse,
|
||||||
|
};
|
||||||
|
use clap::Parser;
|
||||||
|
use futures::SinkExt;
|
||||||
|
use futures::StreamExt;
|
||||||
|
use mls_rs::mls_rs_codec::MlsDecode;
|
||||||
|
use sqlx::sqlite::{SqlitePool, SqlitePoolOptions};
|
||||||
|
use state::MlcState;
|
||||||
|
use tokio::net::TcpListener;
|
||||||
|
use tracing::level_filters::LevelFilter;
|
||||||
|
use tracing_subscriber::{EnvFilter, FmtSubscriber};
|
||||||
|
|
||||||
|
mod args;
|
||||||
|
mod daos;
|
||||||
|
mod routes;
|
||||||
|
mod services;
|
||||||
|
mod state;
|
||||||
|
|
||||||
|
async fn start() -> anyhow::Result<()> {
|
||||||
|
let config = args::MlcConfig::parse();
|
||||||
|
// Get the database URL from environment variables
|
||||||
|
{
|
||||||
|
if let Some(parent) = Path::new(&config.db_path).parent() {
|
||||||
|
std::fs::create_dir_all(&parent)?;
|
||||||
|
}
|
||||||
|
OpenOptions::new()
|
||||||
|
.create(true)
|
||||||
|
.write(true)
|
||||||
|
.open(&config.db_path)?;
|
||||||
|
}
|
||||||
|
tracing::info!("Database path: {}", config.db_path);
|
||||||
|
let sqlite_pool = SqlitePoolOptions::new()
|
||||||
|
.max_connections(10)
|
||||||
|
.min_connections(4)
|
||||||
|
.max_lifetime(None)
|
||||||
|
.connect(&format!("sqlite:{}", config.db_path))
|
||||||
|
.await
|
||||||
|
.expect("Failed to connect to database");
|
||||||
|
|
||||||
|
// Run migrations
|
||||||
|
sqlx::migrate!("./migrations")
|
||||||
|
.run(&sqlite_pool)
|
||||||
|
.await
|
||||||
|
.expect("Failed to run migrations");
|
||||||
|
|
||||||
|
let state = state::MlcState {
|
||||||
|
sqlite_pool,
|
||||||
|
ws_manager: services::WSManager::default(),
|
||||||
|
config: config.clone(),
|
||||||
|
};
|
||||||
|
|
||||||
|
let app = axum::Router::new()
|
||||||
|
.route("/ws", axum::routing::get(ws_handler))
|
||||||
|
.with_state(state);
|
||||||
|
|
||||||
|
let listener = TcpListener::bind((config.server_host, config.server_port)).await?;
|
||||||
|
tracing::info!("Listening on: {}", listener.local_addr()?);
|
||||||
|
axum::serve(listener, app).await?;
|
||||||
|
Ok(())
|
||||||
|
}
|
||||||
|
|
||||||
|
async fn ws_handler(
|
||||||
|
ws: axum::extract::ws::WebSocketUpgrade,
|
||||||
|
headers: HeaderMap,
|
||||||
|
State(db_pool): State<MlcState>,
|
||||||
|
) -> impl IntoResponse {
|
||||||
|
ws.on_upgrade(move |socket| {
|
||||||
|
let username = headers.get("username").unwrap().to_str().unwrap();
|
||||||
|
let (ws_writer, ws_reader) = socket.split();
|
||||||
|
handle_messages(ws_writer, ws_reader, username.to_string(), db_pool)
|
||||||
|
})
|
||||||
|
}
|
||||||
|
|
||||||
|
async fn handle_messages<W, R>(
|
||||||
|
mut ws_writer: W,
|
||||||
|
mut ws_reader: R,
|
||||||
|
username: String,
|
||||||
|
state: MlcState,
|
||||||
|
) where
|
||||||
|
W: futures::Sink<Message> + Unpin,
|
||||||
|
R: futures::Stream<Item = Result<Message, axum::Error>> + Unpin,
|
||||||
|
<W as futures::Sink<Message>>::Error: std::fmt::Display,
|
||||||
|
{
|
||||||
|
let (messages_tx, mut messages_rx) =
|
||||||
|
tokio::sync::mpsc::channel::<String>(state.config.per_user_msg_buffer);
|
||||||
|
let ws_id = state.ws_manager.add_client(&username, messages_tx).await;
|
||||||
|
tracing::Span::current().record("ws_id", ws_id);
|
||||||
|
let mut messages = vec![];
|
||||||
|
loop {
|
||||||
|
tokio::select! {
|
||||||
|
msg = messages_rx.recv() => {
|
||||||
|
if let Some(msg) = msg {
|
||||||
|
ws_writer.send(Message::Text(msg.into())).await.ok();
|
||||||
|
}else{
|
||||||
|
tracing::debug!("Messages channel has been closed");
|
||||||
|
break;
|
||||||
|
}
|
||||||
|
},
|
||||||
|
msg = ws_reader.next() => {
|
||||||
|
let Some(msg_res) = msg else {
|
||||||
|
tracing::debug!("Websocket connection has been closed");
|
||||||
|
break;
|
||||||
|
};
|
||||||
|
let Ok(msg) = msg_res else {
|
||||||
|
tracing::warn!("Cannot receive message from client: {}", msg_res.err().unwrap());
|
||||||
|
break;
|
||||||
|
};
|
||||||
|
match msg {
|
||||||
|
Message::Close(_) => {
|
||||||
|
tracing::debug!("Client has closed the connection");
|
||||||
|
break;
|
||||||
|
}
|
||||||
|
// Message::Ping(bytes) => {ws_writer.send(Message::Pong(bytes)).await.ok();},
|
||||||
|
Message::Text(utf8_bytes) => {
|
||||||
|
tracing::info!("Received message from ws: {:?}", utf8_bytes);
|
||||||
|
if let Err(err) = ws_writer.send(Message::Text(utf8_bytes)).await {
|
||||||
|
tracing::error!("Failed to send message back: {}", err);
|
||||||
|
break;
|
||||||
|
}
|
||||||
|
},
|
||||||
|
Message::Binary(bytes) => {
|
||||||
|
// bytes.
|
||||||
|
let mut raw_bytes = bytes.as_ref();
|
||||||
|
while raw_bytes.len() > 0 {
|
||||||
|
if let Ok(msg) = mls_rs::MlsMessage::mls_decode(&mut raw_bytes){
|
||||||
|
messages.push(msg);
|
||||||
|
} else {break}
|
||||||
|
}
|
||||||
|
},
|
||||||
|
_ => {}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
println!("{:?}", messages);
|
||||||
|
messages.truncate(0);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
fn main() -> anyhow::Result<()> {
|
||||||
|
// Load environment variables from .env file
|
||||||
|
dotenv::dotenv().ok();
|
||||||
|
FmtSubscriber::builder()
|
||||||
|
.with_env_filter(
|
||||||
|
EnvFilter::builder()
|
||||||
|
.with_default_directive(LevelFilter::INFO.into())
|
||||||
|
.from_env_lossy(),
|
||||||
|
)
|
||||||
|
.with_writer(std::io::stderr)
|
||||||
|
.init();
|
||||||
|
|
||||||
|
// Start the HTTP server
|
||||||
|
let rt = tokio::runtime::Builder::new_multi_thread()
|
||||||
|
.enable_all()
|
||||||
|
.build()?;
|
||||||
|
|
||||||
|
rt.block_on(start())
|
||||||
|
}
|
0
mlc-server/src/routes/mod.rs
Normal file
0
mlc-server/src/routes/mod.rs
Normal file
3
mlc-server/src/services/mod.rs
Normal file
3
mlc-server/src/services/mod.rs
Normal file
@ -0,0 +1,3 @@
|
|||||||
|
mod ws_manager;
|
||||||
|
|
||||||
|
pub use ws_manager::WSManager;
|
148
mlc-server/src/services/ws_manager.rs
Normal file
148
mlc-server/src/services/ws_manager.rs
Normal file
@ -0,0 +1,148 @@
|
|||||||
|
use std::sync::Arc;
|
||||||
|
|
||||||
|
use rustc_hash::FxHashMap;
|
||||||
|
use tokio::sync::{RwLock, mpsc::Sender};
|
||||||
|
|
||||||
|
#[derive(Clone, Debug, Default)]
|
||||||
|
pub struct WSManager {
|
||||||
|
connected_clients: Arc<RwLock<FxHashMap<String, Vec<Sender<String>>>>>,
|
||||||
|
}
|
||||||
|
|
||||||
|
impl WSManager {
|
||||||
|
#[tracing::instrument(level = "debug", skip(self, sender))]
|
||||||
|
pub async fn add_client(&self, user_id: &str, sender: Sender<String>) -> usize {
|
||||||
|
let mut connected_clients = self.connected_clients.write().await;
|
||||||
|
let senders = connected_clients
|
||||||
|
.entry(user_id.to_string())
|
||||||
|
.or_insert_with(Vec::new);
|
||||||
|
senders.push(sender);
|
||||||
|
// Minus one, because we use zero-based indexing.
|
||||||
|
senders.len() - 1
|
||||||
|
}
|
||||||
|
|
||||||
|
#[tracing::instrument(level = "debug", skip(self))]
|
||||||
|
pub async fn send_message(&self, user_id: &str, message: &str) -> bool {
|
||||||
|
let mut dead_sockets = vec![];
|
||||||
|
let mut message_sent = false;
|
||||||
|
let user_id_str = user_id.to_string();
|
||||||
|
// Here we are defining a scope to drop the read lock as soon as we are done with it.
|
||||||
|
// Because later we will need to acquire a write lock on the same data.
|
||||||
|
// If we won't drop the read lock, we will get a deadlock.
|
||||||
|
{
|
||||||
|
let connected_clients = self.connected_clients.read().await;
|
||||||
|
if let Some(senders) = connected_clients.get(&user_id_str) {
|
||||||
|
for (id, sender) in senders.iter().enumerate() {
|
||||||
|
if sender.send(message.to_string()).await.is_err() {
|
||||||
|
dead_sockets.push(id);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
// The delivery was successful only if we have
|
||||||
|
// sent the message to at least one socket.
|
||||||
|
if dead_sockets.len() < senders.len() {
|
||||||
|
message_sent = true;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
// If all sockets that are connected to the user are running fine, we can return.
|
||||||
|
if dead_sockets.is_empty() {
|
||||||
|
return message_sent;
|
||||||
|
}
|
||||||
|
|
||||||
|
self.remove_clients(user_id, dead_sockets).await;
|
||||||
|
|
||||||
|
message_sent
|
||||||
|
}
|
||||||
|
|
||||||
|
#[tracing::instrument(level = "debug", skip(self))]
|
||||||
|
pub async fn remove_clients(&self, user_id: &str, ws_ids: Vec<usize>) {
|
||||||
|
let mut connected_clients = self.connected_clients.write().await;
|
||||||
|
let mut senders_left = 0;
|
||||||
|
if let Some(senders) = connected_clients.get_mut(&user_id.to_string()) {
|
||||||
|
for id in ws_ids {
|
||||||
|
if id >= senders.len() {
|
||||||
|
continue;
|
||||||
|
}
|
||||||
|
senders.swap_remove(id);
|
||||||
|
}
|
||||||
|
senders_left = senders.len();
|
||||||
|
}
|
||||||
|
if senders_left == 0 {
|
||||||
|
connected_clients.remove(&user_id.to_string());
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
#[cfg(test)]
|
||||||
|
mod tests {
|
||||||
|
#[tokio::test]
|
||||||
|
async fn add_client() {
|
||||||
|
let ws_manager = super::WSManager::default();
|
||||||
|
let (sender, _) = tokio::sync::mpsc::channel(1);
|
||||||
|
let uid = uuid::Uuid::new_v4().to_string();
|
||||||
|
let client_id = ws_manager.add_client(&uid, sender).await;
|
||||||
|
assert_eq!(client_id, 0);
|
||||||
|
let connected_clients = ws_manager.connected_clients.read().await;
|
||||||
|
let sockets = connected_clients.get(&uid.to_string()).unwrap();
|
||||||
|
assert_eq!(sockets.len(), 1);
|
||||||
|
}
|
||||||
|
|
||||||
|
#[tokio::test]
|
||||||
|
async fn sending_message() {
|
||||||
|
let ws_manager = super::WSManager::default();
|
||||||
|
let (sender, mut receiver) = tokio::sync::mpsc::channel(1);
|
||||||
|
let uid = uuid::Uuid::new_v4().to_string();
|
||||||
|
let client_id = ws_manager.add_client(&uid, sender).await;
|
||||||
|
assert_eq!(client_id, 0);
|
||||||
|
ws_manager.send_message(&uid, "test").await;
|
||||||
|
let message = receiver.recv().await;
|
||||||
|
assert_eq!(message, Some("test".to_string()));
|
||||||
|
}
|
||||||
|
|
||||||
|
#[tokio::test]
|
||||||
|
async fn multiple_clients() {
|
||||||
|
let ws_manager = super::WSManager::default();
|
||||||
|
let mut receivers = vec![];
|
||||||
|
let uid = uuid::Uuid::new_v4().to_string();
|
||||||
|
let num_clients = 10;
|
||||||
|
for i in 0..num_clients {
|
||||||
|
let (sender, receiver) = tokio::sync::mpsc::channel(1);
|
||||||
|
let client_id = ws_manager.add_client(&uid, sender).await;
|
||||||
|
assert_eq!(client_id, i);
|
||||||
|
receivers.push(receiver);
|
||||||
|
}
|
||||||
|
ws_manager.send_message(&uid, "test").await;
|
||||||
|
for mut receiver in receivers {
|
||||||
|
let message = receiver.recv().await;
|
||||||
|
assert_eq!(message, Some("test".to_string()));
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
#[tokio::test]
|
||||||
|
async fn remove_client() {
|
||||||
|
let ws_manager = super::WSManager::default();
|
||||||
|
let (sender, mut receiver) = tokio::sync::mpsc::channel(1);
|
||||||
|
let uid = uuid::Uuid::new_v4().to_string();
|
||||||
|
let client_id = ws_manager.add_client(&uid, sender).await;
|
||||||
|
assert_eq!(client_id, 0);
|
||||||
|
ws_manager.remove_clients(&uid, vec![0, 1, 2]).await;
|
||||||
|
ws_manager.send_message(&uid, "test").await;
|
||||||
|
let message = receiver.recv().await;
|
||||||
|
assert_eq!(message, None);
|
||||||
|
let all_clients = ws_manager.connected_clients.read().await;
|
||||||
|
assert_eq!(all_clients.len(), 0);
|
||||||
|
}
|
||||||
|
|
||||||
|
#[tokio::test]
|
||||||
|
async fn removing_dead_client() {
|
||||||
|
let ws_manager = super::WSManager::default();
|
||||||
|
let (sender, mut receiver) = tokio::sync::mpsc::channel(1);
|
||||||
|
let uid = uuid::Uuid::new_v4().to_string();
|
||||||
|
let client_id = ws_manager.add_client(&uid, sender).await;
|
||||||
|
receiver.close();
|
||||||
|
assert_eq!(client_id, 0);
|
||||||
|
let delivered = ws_manager.send_message(&uid, "test").await;
|
||||||
|
assert_eq!(delivered, false);
|
||||||
|
let all_clients = ws_manager.connected_clients.read().await;
|
||||||
|
assert_eq!(all_clients.len(), 0);
|
||||||
|
}
|
||||||
|
}
|
10
mlc-server/src/state.rs
Normal file
10
mlc-server/src/state.rs
Normal file
@ -0,0 +1,10 @@
|
|||||||
|
use sqlx::SqlitePool;
|
||||||
|
|
||||||
|
use crate::{args::MlcConfig, services::WSManager};
|
||||||
|
|
||||||
|
#[derive(Clone, Debug)]
|
||||||
|
pub struct MlcState {
|
||||||
|
pub sqlite_pool: SqlitePool,
|
||||||
|
pub ws_manager: WSManager,
|
||||||
|
pub config: MlcConfig,
|
||||||
|
}
|
Reference in New Issue
Block a user