use chrono::{DateTime, Local}; use tokio_rusqlite::{named_params, Connection, Result}; use crate::db::utils; use crate::model::{decode_datetime, encode_datetime, File}; pub async fn insert(conn: &Connection, file: File) -> Result<()> { let query = r#" INSERT INTO files(id, created_at, expires_at, filename, content_length) VALUES (:id, datetime(), :expires_at, :name, :content_length) "#; conn.call(move |conn| { conn.execute( query, named_params![ ":id": file.id, ":expires_at": encode_datetime(file.expires_at), ":name": file.name, ":content_length": file.content_length ], ) .map_err(tokio_rusqlite::Error::Rusqlite) }) .await .map(|_| ()) } pub async fn get(conn: &Connection, file_id: impl Into) -> Result> { let file_id = file_id.into(); let query = r#" SELECT filename, expires_at, content_length FROM files WHERE id = :id AND expires_at > datetime() "#; conn.call(move |conn| { let mut stmt = conn.prepare(query)?; let mut iter = stmt.query_map(named_params![":id": file_id], |row| { let res: (String, String, usize) = (row.get(0)?, row.get(1)?, row.get(2)?); Ok(res) })?; match iter.next() { Some(Ok((filename, expires_at, content_length))) => { match decode_datetime(&expires_at) { Some(expires_at) => Ok(Some(File { id: file_id, name: filename, expires_at, content_length, })), _ => Err(utils::rusqlite_other_error(format!( "Error decoding datetime: {expires_at}" ))), } } Some(_) => Err(utils::rusqlite_other_error("Error reading file in DB")), None => Ok(None), } }) .await } pub async fn list_expire_after(conn: &Connection, time: DateTime) -> Result> { let query = r#" SELECT id FROM files WHERE expires_at > :expires_at "#; conn.call(move |conn| { let mut stmt = conn.prepare(query)?; let iter = stmt.query_map(named_params![":expires_at": encode_datetime(time)], |row| { row.get(0) })?; let mut res = vec![]; for id in iter { res.push(id?) } Ok(res) }) .await } pub async fn remove_expire_before(conn: &Connection, time: DateTime) -> Result<()> { let query = r#" DELETE FROM files WHERE expires_at <= :expires_at "#; conn.call(move |conn| { conn.execute(query, named_params![":expires_at": encode_datetime(time)])?; Ok(()) }) .await } #[cfg(test)] mod tests { use super::*; use crate::model::{generate_file_id, local_time}; use chrono::Duration; use std::collections::HashSet; use std::ops::Add; #[tokio::test] async fn test_insert_and_get_file() { let conn = get_connection().await; let file = dummy_file(Duration::minutes(1)); assert!(insert(&conn, file.clone()).await.is_ok()); let file_res = get(&conn, file.id.clone()).await; assert!(file_res.is_ok()); assert_eq!(file_res.unwrap(), Some(file)); } #[tokio::test] async fn test_expired_file_err() { let conn = get_connection().await; let file = dummy_file(Duration::zero()); assert!(insert(&conn, file.clone()).await.is_ok()); let file_res = get(&conn, file.id.clone()).await; assert!(file_res.is_ok()); assert!(file_res.unwrap().is_none()); } #[tokio::test] async fn test_wrong_file_err() { let conn = get_connection().await; let file = dummy_file(Duration::minutes(1)); assert!(insert(&conn, file.clone()).await.is_ok()); let file_res = get(&conn, "wrong-id".to_string()).await; assert!(file_res.is_ok()); assert!(file_res.unwrap().is_none()); } #[tokio::test] async fn test_list_non_expirable() { let conn = get_connection().await; let file_expire = dummy_file(Duration::zero()); let file_no_expire_1 = dummy_file(Duration::minutes(1)); let file_no_expire_2 = dummy_file(Duration::minutes(1)); assert!(insert(&conn, file_expire.clone()).await.is_ok()); assert!(insert(&conn, file_no_expire_1.clone()).await.is_ok()); assert!(insert(&conn, file_no_expire_2.clone()).await.is_ok()); let list = list_expire_after(&conn, Local::now()).await; assert!(list.is_ok()); assert_eq!( HashSet::from_iter(list.unwrap().iter()), HashSet::from([&file_no_expire_1.id, &file_no_expire_2.id]) ) } #[tokio::test] async fn test_remove_expire_before() { let conn = get_connection().await; let file_1 = dummy_file(Duration::zero()); let file_2 = dummy_file(Duration::zero()); let file_3 = dummy_file(Duration::minutes(1)); let file_4 = dummy_file(Duration::minutes(1)); assert!(insert(&conn, file_1.clone()).await.is_ok()); assert!(insert(&conn, file_2.clone()).await.is_ok()); assert!(insert(&conn, file_3.clone()).await.is_ok()); assert!(insert(&conn, file_4.clone()).await.is_ok()); assert!(remove_expire_before(&conn, Local::now()).await.is_ok()); assert!(get(&conn, file_1.id).await.unwrap().is_none()); assert!(get(&conn, file_2.id).await.unwrap().is_none()); assert!(get(&conn, file_3.id).await.unwrap().is_some()); assert!(get(&conn, file_4.id).await.unwrap().is_some()); } fn dummy_file(td: Duration) -> File { File { id: generate_file_id(), name: "foo".to_string(), expires_at: local_time().add(td), content_length: 100, } } async fn get_connection() -> Connection { let conn = Connection::open_in_memory().await.unwrap(); let res = crate::db::apply_migrations(&conn).await; assert!(res.is_ok()); conn } }