aboutsummaryrefslogtreecommitdiff
path: root/src/db
diff options
context:
space:
mode:
Diffstat (limited to 'src/db')
-rw-r--r--src/db/files.rs193
-rw-r--r--src/db/migrations/01-init.sql7
-rw-r--r--src/db/migrations/02-strict-tables.sql15
-rw-r--r--src/db/mod.rs47
-rw-r--r--src/db/utils.rs3
5 files changed, 265 insertions, 0 deletions
diff --git a/src/db/files.rs b/src/db/files.rs
new file mode 100644
index 0000000..995a2ed
--- /dev/null
+++ b/src/db/files.rs
@@ -0,0 +1,193 @@
+use chrono::{DateTime, Local};
+use tokio_rusqlite::{named_params, Connection, Result};
+
+use crate::db::utils;
+use crate::model::{decode_datetime, encode_datetime, File};
+
+pub async fn insert(conn: &Connection, file: File) -> Result<()> {
+ let query = r#"
+ INSERT INTO files(id, created_at, expires_at, filename, content_length)
+ VALUES (:id, datetime(), :expires_at, :name, :content_length)
+ "#;
+
+ conn.call(move |conn| {
+ conn.execute(
+ query,
+ named_params![
+ ":id": file.id,
+ ":expires_at": encode_datetime(file.expires_at),
+ ":name": file.name,
+ ":content_length": file.content_length
+ ],
+ )
+ .map_err(tokio_rusqlite::Error::Rusqlite)
+ })
+ .await
+ .map(|_| ())
+}
+
+pub async fn get(conn: &Connection, file_id: impl Into<String>) -> Result<Option<File>> {
+ let file_id = file_id.into();
+
+ let query = r#"
+ SELECT filename, expires_at, content_length
+ FROM files
+ WHERE
+ id = :id
+ AND expires_at > datetime()
+ "#;
+
+ conn.call(move |conn| {
+ let mut stmt = conn.prepare(query)?;
+
+ let mut iter = stmt.query_map(named_params![":id": file_id], |row| {
+ let res: (String, String, usize) = (row.get(0)?, row.get(1)?, row.get(2)?);
+ Ok(res)
+ })?;
+
+ match iter.next() {
+ Some(Ok((filename, expires_at, content_length))) => {
+ match decode_datetime(&expires_at) {
+ Some(expires_at) => Ok(Some(File {
+ id: file_id,
+ name: filename,
+ expires_at,
+ content_length,
+ })),
+ _ => Err(utils::rusqlite_other_error(format!(
+ "Error decoding datetime: {expires_at}"
+ ))),
+ }
+ }
+ Some(_) => Err(utils::rusqlite_other_error("Error reading file in DB")),
+ None => Ok(None),
+ }
+ })
+ .await
+}
+
+pub async fn list_expire_after(conn: &Connection, time: DateTime<Local>) -> Result<Vec<String>> {
+ let query = r#"
+ SELECT id
+ FROM files
+ WHERE expires_at > :expires_at
+ "#;
+
+ conn.call(move |conn| {
+ let mut stmt = conn.prepare(query)?;
+
+ let iter = stmt.query_map(named_params![":expires_at": encode_datetime(time)], |row| {
+ row.get(0)
+ })?;
+
+ let mut res = vec![];
+ for id in iter {
+ res.push(id?)
+ }
+ Ok(res)
+ })
+ .await
+}
+
+pub async fn remove_expire_before(conn: &Connection, time: DateTime<Local>) -> Result<()> {
+ let query = r#"
+ DELETE FROM files
+ WHERE expires_at <= :expires_at
+ "#;
+
+ conn.call(move |conn| {
+ conn.execute(query, named_params![":expires_at": encode_datetime(time)])?;
+ Ok(())
+ })
+ .await
+}
+
+#[cfg(test)]
+mod tests {
+ use super::*;
+ use crate::model::{generate_file_id, local_time};
+ use chrono::Duration;
+ use std::collections::HashSet;
+ use std::ops::Add;
+
+ #[tokio::test]
+ async fn test_insert_and_get_file() {
+ let conn = get_connection().await;
+ let file = dummy_file(Duration::minutes(1));
+ assert!(insert(&conn, file.clone()).await.is_ok());
+ let file_res = get(&conn, file.id.clone()).await;
+ assert!(file_res.is_ok());
+ assert_eq!(file_res.unwrap(), Some(file));
+ }
+
+ #[tokio::test]
+ async fn test_expired_file_err() {
+ let conn = get_connection().await;
+ let file = dummy_file(Duration::zero());
+ assert!(insert(&conn, file.clone()).await.is_ok());
+ let file_res = get(&conn, file.id.clone()).await;
+ assert!(file_res.is_ok());
+ assert!(file_res.unwrap().is_none());
+ }
+
+ #[tokio::test]
+ async fn test_wrong_file_err() {
+ let conn = get_connection().await;
+ let file = dummy_file(Duration::minutes(1));
+ assert!(insert(&conn, file.clone()).await.is_ok());
+ let file_res = get(&conn, "wrong-id".to_string()).await;
+ assert!(file_res.is_ok());
+ assert!(file_res.unwrap().is_none());
+ }
+
+ #[tokio::test]
+ async fn test_list_non_expirable() {
+ let conn = get_connection().await;
+ let file_expire = dummy_file(Duration::zero());
+ let file_no_expire_1 = dummy_file(Duration::minutes(1));
+ let file_no_expire_2 = dummy_file(Duration::minutes(1));
+ assert!(insert(&conn, file_expire.clone()).await.is_ok());
+ assert!(insert(&conn, file_no_expire_1.clone()).await.is_ok());
+ assert!(insert(&conn, file_no_expire_2.clone()).await.is_ok());
+ let list = list_expire_after(&conn, Local::now()).await;
+ assert!(list.is_ok());
+ assert_eq!(
+ HashSet::from_iter(list.unwrap().iter()),
+ HashSet::from([&file_no_expire_1.id, &file_no_expire_2.id])
+ )
+ }
+
+ #[tokio::test]
+ async fn test_remove_expire_before() {
+ let conn = get_connection().await;
+ let file_1 = dummy_file(Duration::zero());
+ let file_2 = dummy_file(Duration::zero());
+ let file_3 = dummy_file(Duration::minutes(1));
+ let file_4 = dummy_file(Duration::minutes(1));
+ assert!(insert(&conn, file_1.clone()).await.is_ok());
+ assert!(insert(&conn, file_2.clone()).await.is_ok());
+ assert!(insert(&conn, file_3.clone()).await.is_ok());
+ assert!(insert(&conn, file_4.clone()).await.is_ok());
+ assert!(remove_expire_before(&conn, Local::now()).await.is_ok());
+ assert!(get(&conn, file_1.id).await.unwrap().is_none());
+ assert!(get(&conn, file_2.id).await.unwrap().is_none());
+ assert!(get(&conn, file_3.id).await.unwrap().is_some());
+ assert!(get(&conn, file_4.id).await.unwrap().is_some());
+ }
+
+ fn dummy_file(td: Duration) -> File {
+ File {
+ id: generate_file_id(),
+ name: "foo".to_string(),
+ expires_at: local_time().add(td),
+ content_length: 100,
+ }
+ }
+
+ async fn get_connection() -> Connection {
+ let conn = Connection::open_in_memory().await.unwrap();
+ let res = crate::db::apply_migrations(&conn).await;
+ assert!(res.is_ok());
+ conn
+ }
+}
diff --git a/src/db/migrations/01-init.sql b/src/db/migrations/01-init.sql
new file mode 100644
index 0000000..75abc54
--- /dev/null
+++ b/src/db/migrations/01-init.sql
@@ -0,0 +1,7 @@
+CREATE TABLE IF NOT EXISTS "files" (
+ id TEXT PRIMARY KEY,
+ created_at TEXT NOT NULL,
+ expires_at TEXT NOT NULL,
+ filename TEXT NOT NULL,
+ content_length INTEGER NOT NULL
+);
diff --git a/src/db/migrations/02-strict-tables.sql b/src/db/migrations/02-strict-tables.sql
new file mode 100644
index 0000000..433ef39
--- /dev/null
+++ b/src/db/migrations/02-strict-tables.sql
@@ -0,0 +1,15 @@
+ALTER TABLE "files" RENAME TO "files_non_strict";
+
+CREATE TABLE IF NOT EXISTS "files" (
+ id TEXT PRIMARY KEY,
+ created_at TEXT NOT NULL,
+ expires_at TEXT NOT NULL,
+ filename TEXT NOT NULL,
+ content_length INTEGER NOT NULL
+) STRICT;
+
+INSERT INTO files (id, created_at, expires_at, filename, content_length)
+ SELECT id, created_at, expires_at, filename, content_length
+ FROM files_non_strict;
+
+DROP TABLE files_non_strict;
diff --git a/src/db/mod.rs b/src/db/mod.rs
new file mode 100644
index 0000000..37769d2
--- /dev/null
+++ b/src/db/mod.rs
@@ -0,0 +1,47 @@
+use anyhow::{Error, Result};
+use rusqlite_migration::{Migrations, M};
+use tokio_rusqlite::Connection;
+
+pub mod files;
+mod utils;
+
+pub async fn init(path: &str) -> Result<Connection> {
+ let connection = Connection::open(path)
+ .await
+ .map_err(|err| Error::msg(format!("Error opening connection: {err}")))?;
+
+ apply_migrations(&connection).await?;
+ set_pragma(&connection, "foreign_keys", "ON").await?;
+ set_pragma(&connection, "journal_mode", "wal").await?;
+ Ok(connection)
+}
+
+async fn apply_migrations(conn: &Connection) -> Result<()> {
+ let migrations = Migrations::new(vec![
+ M::up(include_str!("migrations/01-init.sql")),
+ M::up(include_str!("migrations/02-strict-tables.sql")),
+ ]);
+
+ Ok(conn
+ .call(move |conn| {
+ migrations
+ .to_latest(conn)
+ .map_err(|migration_err| tokio_rusqlite::Error::Other(Box::new(migration_err)))
+ })
+ .await?)
+}
+
+async fn set_pragma(
+ conn: &Connection,
+ key: impl Into<String>,
+ value: impl Into<String>,
+) -> Result<()> {
+ let key = key.into();
+ let value = value.into();
+ Ok(conn
+ .call(move |conn| {
+ conn.pragma_update(None, &key, &value)
+ .map_err(tokio_rusqlite::Error::Rusqlite)
+ })
+ .await?)
+}
diff --git a/src/db/utils.rs b/src/db/utils.rs
new file mode 100644
index 0000000..0b3e029
--- /dev/null
+++ b/src/db/utils.rs
@@ -0,0 +1,3 @@
+pub fn rusqlite_other_error(msg: impl Into<String>) -> tokio_rusqlite::Error {
+ tokio_rusqlite::Error::Other(msg.into().into())
+}