From 5e2aee9248a00c8b213a8e07e4796d668bff519c Mon Sep 17 00:00:00 2001 From: Joris Date: Fri, 7 Feb 2025 09:17:26 +0100 Subject: Apply SQL migrations at startup --- src/db/files.rs | 193 +++++++++++++++++++++++++++++++++ src/db/migrations/01-init.sql | 7 ++ src/db/migrations/02-strict-tables.sql | 15 +++ src/db/mod.rs | 47 ++++++++ src/db/utils.rs | 3 + 5 files changed, 265 insertions(+) create mode 100644 src/db/files.rs create mode 100644 src/db/migrations/01-init.sql create mode 100644 src/db/migrations/02-strict-tables.sql create mode 100644 src/db/mod.rs create mode 100644 src/db/utils.rs (limited to 'src/db') diff --git a/src/db/files.rs b/src/db/files.rs new file mode 100644 index 0000000..995a2ed --- /dev/null +++ b/src/db/files.rs @@ -0,0 +1,193 @@ +use chrono::{DateTime, Local}; +use tokio_rusqlite::{named_params, Connection, Result}; + +use crate::db::utils; +use crate::model::{decode_datetime, encode_datetime, File}; + +pub async fn insert(conn: &Connection, file: File) -> Result<()> { + let query = r#" + INSERT INTO files(id, created_at, expires_at, filename, content_length) + VALUES (:id, datetime(), :expires_at, :name, :content_length) + "#; + + conn.call(move |conn| { + conn.execute( + query, + named_params![ + ":id": file.id, + ":expires_at": encode_datetime(file.expires_at), + ":name": file.name, + ":content_length": file.content_length + ], + ) + .map_err(tokio_rusqlite::Error::Rusqlite) + }) + .await + .map(|_| ()) +} + +pub async fn get(conn: &Connection, file_id: impl Into) -> Result> { + let file_id = file_id.into(); + + let query = r#" + SELECT filename, expires_at, content_length + FROM files + WHERE + id = :id + AND expires_at > datetime() + "#; + + conn.call(move |conn| { + let mut stmt = conn.prepare(query)?; + + let mut iter = stmt.query_map(named_params![":id": file_id], |row| { + let res: (String, String, usize) = (row.get(0)?, row.get(1)?, row.get(2)?); + Ok(res) + })?; + + match iter.next() { + Some(Ok((filename, expires_at, content_length))) => { + match decode_datetime(&expires_at) { + Some(expires_at) => Ok(Some(File { + id: file_id, + name: filename, + expires_at, + content_length, + })), + _ => Err(utils::rusqlite_other_error(format!( + "Error decoding datetime: {expires_at}" + ))), + } + } + Some(_) => Err(utils::rusqlite_other_error("Error reading file in DB")), + None => Ok(None), + } + }) + .await +} + +pub async fn list_expire_after(conn: &Connection, time: DateTime) -> Result> { + let query = r#" + SELECT id + FROM files + WHERE expires_at > :expires_at + "#; + + conn.call(move |conn| { + let mut stmt = conn.prepare(query)?; + + let iter = stmt.query_map(named_params![":expires_at": encode_datetime(time)], |row| { + row.get(0) + })?; + + let mut res = vec![]; + for id in iter { + res.push(id?) + } + Ok(res) + }) + .await +} + +pub async fn remove_expire_before(conn: &Connection, time: DateTime) -> Result<()> { + let query = r#" + DELETE FROM files + WHERE expires_at <= :expires_at + "#; + + conn.call(move |conn| { + conn.execute(query, named_params![":expires_at": encode_datetime(time)])?; + Ok(()) + }) + .await +} + +#[cfg(test)] +mod tests { + use super::*; + use crate::model::{generate_file_id, local_time}; + use chrono::Duration; + use std::collections::HashSet; + use std::ops::Add; + + #[tokio::test] + async fn test_insert_and_get_file() { + let conn = get_connection().await; + let file = dummy_file(Duration::minutes(1)); + assert!(insert(&conn, file.clone()).await.is_ok()); + let file_res = get(&conn, file.id.clone()).await; + assert!(file_res.is_ok()); + assert_eq!(file_res.unwrap(), Some(file)); + } + + #[tokio::test] + async fn test_expired_file_err() { + let conn = get_connection().await; + let file = dummy_file(Duration::zero()); + assert!(insert(&conn, file.clone()).await.is_ok()); + let file_res = get(&conn, file.id.clone()).await; + assert!(file_res.is_ok()); + assert!(file_res.unwrap().is_none()); + } + + #[tokio::test] + async fn test_wrong_file_err() { + let conn = get_connection().await; + let file = dummy_file(Duration::minutes(1)); + assert!(insert(&conn, file.clone()).await.is_ok()); + let file_res = get(&conn, "wrong-id".to_string()).await; + assert!(file_res.is_ok()); + assert!(file_res.unwrap().is_none()); + } + + #[tokio::test] + async fn test_list_non_expirable() { + let conn = get_connection().await; + let file_expire = dummy_file(Duration::zero()); + let file_no_expire_1 = dummy_file(Duration::minutes(1)); + let file_no_expire_2 = dummy_file(Duration::minutes(1)); + assert!(insert(&conn, file_expire.clone()).await.is_ok()); + assert!(insert(&conn, file_no_expire_1.clone()).await.is_ok()); + assert!(insert(&conn, file_no_expire_2.clone()).await.is_ok()); + let list = list_expire_after(&conn, Local::now()).await; + assert!(list.is_ok()); + assert_eq!( + HashSet::from_iter(list.unwrap().iter()), + HashSet::from([&file_no_expire_1.id, &file_no_expire_2.id]) + ) + } + + #[tokio::test] + async fn test_remove_expire_before() { + let conn = get_connection().await; + let file_1 = dummy_file(Duration::zero()); + let file_2 = dummy_file(Duration::zero()); + let file_3 = dummy_file(Duration::minutes(1)); + let file_4 = dummy_file(Duration::minutes(1)); + assert!(insert(&conn, file_1.clone()).await.is_ok()); + assert!(insert(&conn, file_2.clone()).await.is_ok()); + assert!(insert(&conn, file_3.clone()).await.is_ok()); + assert!(insert(&conn, file_4.clone()).await.is_ok()); + assert!(remove_expire_before(&conn, Local::now()).await.is_ok()); + assert!(get(&conn, file_1.id).await.unwrap().is_none()); + assert!(get(&conn, file_2.id).await.unwrap().is_none()); + assert!(get(&conn, file_3.id).await.unwrap().is_some()); + assert!(get(&conn, file_4.id).await.unwrap().is_some()); + } + + fn dummy_file(td: Duration) -> File { + File { + id: generate_file_id(), + name: "foo".to_string(), + expires_at: local_time().add(td), + content_length: 100, + } + } + + async fn get_connection() -> Connection { + let conn = Connection::open_in_memory().await.unwrap(); + let res = crate::db::apply_migrations(&conn).await; + assert!(res.is_ok()); + conn + } +} diff --git a/src/db/migrations/01-init.sql b/src/db/migrations/01-init.sql new file mode 100644 index 0000000..75abc54 --- /dev/null +++ b/src/db/migrations/01-init.sql @@ -0,0 +1,7 @@ +CREATE TABLE IF NOT EXISTS "files" ( + id TEXT PRIMARY KEY, + created_at TEXT NOT NULL, + expires_at TEXT NOT NULL, + filename TEXT NOT NULL, + content_length INTEGER NOT NULL +); diff --git a/src/db/migrations/02-strict-tables.sql b/src/db/migrations/02-strict-tables.sql new file mode 100644 index 0000000..433ef39 --- /dev/null +++ b/src/db/migrations/02-strict-tables.sql @@ -0,0 +1,15 @@ +ALTER TABLE "files" RENAME TO "files_non_strict"; + +CREATE TABLE IF NOT EXISTS "files" ( + id TEXT PRIMARY KEY, + created_at TEXT NOT NULL, + expires_at TEXT NOT NULL, + filename TEXT NOT NULL, + content_length INTEGER NOT NULL +) STRICT; + +INSERT INTO files (id, created_at, expires_at, filename, content_length) + SELECT id, created_at, expires_at, filename, content_length + FROM files_non_strict; + +DROP TABLE files_non_strict; diff --git a/src/db/mod.rs b/src/db/mod.rs new file mode 100644 index 0000000..37769d2 --- /dev/null +++ b/src/db/mod.rs @@ -0,0 +1,47 @@ +use anyhow::{Error, Result}; +use rusqlite_migration::{Migrations, M}; +use tokio_rusqlite::Connection; + +pub mod files; +mod utils; + +pub async fn init(path: &str) -> Result { + let connection = Connection::open(path) + .await + .map_err(|err| Error::msg(format!("Error opening connection: {err}")))?; + + apply_migrations(&connection).await?; + set_pragma(&connection, "foreign_keys", "ON").await?; + set_pragma(&connection, "journal_mode", "wal").await?; + Ok(connection) +} + +async fn apply_migrations(conn: &Connection) -> Result<()> { + let migrations = Migrations::new(vec![ + M::up(include_str!("migrations/01-init.sql")), + M::up(include_str!("migrations/02-strict-tables.sql")), + ]); + + Ok(conn + .call(move |conn| { + migrations + .to_latest(conn) + .map_err(|migration_err| tokio_rusqlite::Error::Other(Box::new(migration_err))) + }) + .await?) +} + +async fn set_pragma( + conn: &Connection, + key: impl Into, + value: impl Into, +) -> Result<()> { + let key = key.into(); + let value = value.into(); + Ok(conn + .call(move |conn| { + conn.pragma_update(None, &key, &value) + .map_err(tokio_rusqlite::Error::Rusqlite) + }) + .await?) +} diff --git a/src/db/utils.rs b/src/db/utils.rs new file mode 100644 index 0000000..0b3e029 --- /dev/null +++ b/src/db/utils.rs @@ -0,0 +1,3 @@ +pub fn rusqlite_other_error(msg: impl Into) -> tokio_rusqlite::Error { + tokio_rusqlite::Error::Other(msg.into().into()) +} -- cgit v1.2.3