diff options
Diffstat (limited to 'src/db/files.rs')
-rw-r--r-- | src/db/files.rs | 193 |
1 files changed, 193 insertions, 0 deletions
diff --git a/src/db/files.rs b/src/db/files.rs new file mode 100644 index 0000000..995a2ed --- /dev/null +++ b/src/db/files.rs @@ -0,0 +1,193 @@ +use chrono::{DateTime, Local}; +use tokio_rusqlite::{named_params, Connection, Result}; + +use crate::db::utils; +use crate::model::{decode_datetime, encode_datetime, File}; + +pub async fn insert(conn: &Connection, file: File) -> Result<()> { + let query = r#" + INSERT INTO files(id, created_at, expires_at, filename, content_length) + VALUES (:id, datetime(), :expires_at, :name, :content_length) + "#; + + conn.call(move |conn| { + conn.execute( + query, + named_params![ + ":id": file.id, + ":expires_at": encode_datetime(file.expires_at), + ":name": file.name, + ":content_length": file.content_length + ], + ) + .map_err(tokio_rusqlite::Error::Rusqlite) + }) + .await + .map(|_| ()) +} + +pub async fn get(conn: &Connection, file_id: impl Into<String>) -> Result<Option<File>> { + let file_id = file_id.into(); + + let query = r#" + SELECT filename, expires_at, content_length + FROM files + WHERE + id = :id + AND expires_at > datetime() + "#; + + conn.call(move |conn| { + let mut stmt = conn.prepare(query)?; + + let mut iter = stmt.query_map(named_params![":id": file_id], |row| { + let res: (String, String, usize) = (row.get(0)?, row.get(1)?, row.get(2)?); + Ok(res) + })?; + + match iter.next() { + Some(Ok((filename, expires_at, content_length))) => { + match decode_datetime(&expires_at) { + Some(expires_at) => Ok(Some(File { + id: file_id, + name: filename, + expires_at, + content_length, + })), + _ => Err(utils::rusqlite_other_error(format!( + "Error decoding datetime: {expires_at}" + ))), + } + } + Some(_) => Err(utils::rusqlite_other_error("Error reading file in DB")), + None => Ok(None), + } + }) + .await +} + +pub async fn list_expire_after(conn: &Connection, time: DateTime<Local>) -> Result<Vec<String>> { + let query = r#" + SELECT id + FROM files + WHERE expires_at > :expires_at + "#; + + conn.call(move |conn| { + let mut stmt = conn.prepare(query)?; + + let iter = stmt.query_map(named_params![":expires_at": encode_datetime(time)], |row| { + row.get(0) + })?; + + let mut res = vec![]; + for id in iter { + res.push(id?) + } + Ok(res) + }) + .await +} + +pub async fn remove_expire_before(conn: &Connection, time: DateTime<Local>) -> Result<()> { + let query = r#" + DELETE FROM files + WHERE expires_at <= :expires_at + "#; + + conn.call(move |conn| { + conn.execute(query, named_params![":expires_at": encode_datetime(time)])?; + Ok(()) + }) + .await +} + +#[cfg(test)] +mod tests { + use super::*; + use crate::model::{generate_file_id, local_time}; + use chrono::Duration; + use std::collections::HashSet; + use std::ops::Add; + + #[tokio::test] + async fn test_insert_and_get_file() { + let conn = get_connection().await; + let file = dummy_file(Duration::minutes(1)); + assert!(insert(&conn, file.clone()).await.is_ok()); + let file_res = get(&conn, file.id.clone()).await; + assert!(file_res.is_ok()); + assert_eq!(file_res.unwrap(), Some(file)); + } + + #[tokio::test] + async fn test_expired_file_err() { + let conn = get_connection().await; + let file = dummy_file(Duration::zero()); + assert!(insert(&conn, file.clone()).await.is_ok()); + let file_res = get(&conn, file.id.clone()).await; + assert!(file_res.is_ok()); + assert!(file_res.unwrap().is_none()); + } + + #[tokio::test] + async fn test_wrong_file_err() { + let conn = get_connection().await; + let file = dummy_file(Duration::minutes(1)); + assert!(insert(&conn, file.clone()).await.is_ok()); + let file_res = get(&conn, "wrong-id".to_string()).await; + assert!(file_res.is_ok()); + assert!(file_res.unwrap().is_none()); + } + + #[tokio::test] + async fn test_list_non_expirable() { + let conn = get_connection().await; + let file_expire = dummy_file(Duration::zero()); + let file_no_expire_1 = dummy_file(Duration::minutes(1)); + let file_no_expire_2 = dummy_file(Duration::minutes(1)); + assert!(insert(&conn, file_expire.clone()).await.is_ok()); + assert!(insert(&conn, file_no_expire_1.clone()).await.is_ok()); + assert!(insert(&conn, file_no_expire_2.clone()).await.is_ok()); + let list = list_expire_after(&conn, Local::now()).await; + assert!(list.is_ok()); + assert_eq!( + HashSet::from_iter(list.unwrap().iter()), + HashSet::from([&file_no_expire_1.id, &file_no_expire_2.id]) + ) + } + + #[tokio::test] + async fn test_remove_expire_before() { + let conn = get_connection().await; + let file_1 = dummy_file(Duration::zero()); + let file_2 = dummy_file(Duration::zero()); + let file_3 = dummy_file(Duration::minutes(1)); + let file_4 = dummy_file(Duration::minutes(1)); + assert!(insert(&conn, file_1.clone()).await.is_ok()); + assert!(insert(&conn, file_2.clone()).await.is_ok()); + assert!(insert(&conn, file_3.clone()).await.is_ok()); + assert!(insert(&conn, file_4.clone()).await.is_ok()); + assert!(remove_expire_before(&conn, Local::now()).await.is_ok()); + assert!(get(&conn, file_1.id).await.unwrap().is_none()); + assert!(get(&conn, file_2.id).await.unwrap().is_none()); + assert!(get(&conn, file_3.id).await.unwrap().is_some()); + assert!(get(&conn, file_4.id).await.unwrap().is_some()); + } + + fn dummy_file(td: Duration) -> File { + File { + id: generate_file_id(), + name: "foo".to_string(), + expires_at: local_time().add(td), + content_length: 100, + } + } + + async fn get_connection() -> Connection { + let conn = Connection::open_in_memory().await.unwrap(); + let res = crate::db::apply_migrations(&conn).await; + assert!(res.is_ok()); + conn + } +} |