use chrono::{DateTime, Local}; use tokio_rusqlite::{params, Connection, Result}; use crate::model::{decode_datetime, encode_datetime, File}; pub async fn insert_file(conn: &Connection, file: File) -> Result<()> { conn.call(move |conn| { conn.execute( r#" INSERT INTO files(id, created_at, expires_at, filename, content_length) VALUES (?1, datetime(), ?2, ?3, ?4) "#, params![ file.id, encode_datetime(file.expires_at), file.name, file.content_length ], ) .map_err(tokio_rusqlite::Error::Rusqlite) }) .await .map(|_| ()) } pub async fn get_file(conn: &Connection, file_id: String) -> Result> { conn.call(move |conn| { let mut stmt = conn.prepare( r#" SELECT filename, expires_at, content_length FROM files WHERE id = ? AND expires_at > datetime() "#, )?; let mut iter = stmt.query_map([file_id.clone()], |row| { let res: (String, String, usize) = (row.get(0)?, row.get(1)?, row.get(2)?); Ok(res) })?; match iter.next() { Some(Ok((filename, expires_at, content_length))) => { match decode_datetime(&expires_at) { Some(expires_at) => Ok(Some(File { id: file_id.clone(), name: filename, expires_at, content_length, })), _ => Err(rusqlite_other_error(&format!( "Error decoding datetime: {expires_at}" ))), } } Some(_) => Err(rusqlite_other_error("Error reading file in DB")), None => Ok(None), } }) .await } pub async fn list_expire_after(conn: &Connection, time: DateTime) -> Result> { conn.call(move |conn| { let mut stmt = conn.prepare( r#" SELECT id FROM files WHERE expires_at > ? "#, )?; let iter = stmt.query_map([encode_datetime(time)], |row| row.get(0))?; let mut res = vec![]; for id in iter { res.push(id?) } Ok(res) }) .await } pub async fn remove_expire_before(conn: &Connection, time: DateTime) -> Result<()> { conn.call(move |conn| { conn.execute( &format!( r#" DELETE FROM files WHERE expires_at <= ? "# ), [encode_datetime(time)] )?; Ok(()) }) .await } fn rusqlite_other_error(msg: &str) -> tokio_rusqlite::Error { tokio_rusqlite::Error::Other(msg.into()) } #[cfg(test)] mod tests { use super::*; use crate::model::{generate_file_id, local_time}; use chrono::Duration; use std::collections::HashSet; use std::ops::Add; #[tokio::test] async fn test_insert_and_get_file() { let conn = get_connection().await; let file = dummy_file(Duration::minutes(1)); assert!(insert_file(&conn, file.clone()).await.is_ok()); let file_res = get_file(&conn, file.id.clone()).await; assert!(file_res.is_ok()); assert_eq!(file_res.unwrap(), Some(file)); } #[tokio::test] async fn test_expired_file_err() { let conn = get_connection().await; let file = dummy_file(Duration::zero()); assert!(insert_file(&conn, file.clone()).await.is_ok()); let file_res = get_file(&conn, file.id.clone()).await; assert!(file_res.is_ok()); assert!(file_res.unwrap().is_none()); } #[tokio::test] async fn test_wrong_file_err() { let conn = get_connection().await; let file = dummy_file(Duration::minutes(1)); assert!(insert_file(&conn, file.clone()).await.is_ok()); let file_res = get_file(&conn, "wrong-id".to_string()).await; assert!(file_res.is_ok()); assert!(file_res.unwrap().is_none()); } #[tokio::test] async fn test_list_non_expirable() { let conn = get_connection().await; let file_expire = dummy_file(Duration::zero()); let file_no_expire_1 = dummy_file(Duration::minutes(1)); let file_no_expire_2 = dummy_file(Duration::minutes(1)); assert!(insert_file(&conn, file_expire.clone()).await.is_ok()); assert!(insert_file(&conn, file_no_expire_1.clone()).await.is_ok()); assert!(insert_file(&conn, file_no_expire_2.clone()).await.is_ok()); let list = list_expire_after(&conn, Local::now()).await; assert!(list.is_ok()); assert_eq!( HashSet::from_iter(list.unwrap().iter()), HashSet::from([&file_no_expire_1.id, &file_no_expire_2.id]) ) } #[tokio::test] async fn test_remove_expire_before() { let conn = get_connection().await; let file_1 = dummy_file(Duration::zero()); let file_2 = dummy_file(Duration::zero()); let file_3 = dummy_file(Duration::minutes(1)); let file_4 = dummy_file(Duration::minutes(1)); assert!(insert_file(&conn, file_1.clone()).await.is_ok()); assert!(insert_file(&conn, file_2.clone()).await.is_ok()); assert!(insert_file(&conn, file_3.clone()).await.is_ok()); assert!(insert_file(&conn, file_4.clone()).await.is_ok()); assert!(remove_expire_before(&conn, Local::now()).await.is_ok()); assert!(get_file(&conn, file_1.id).await.unwrap().is_none()); assert!(get_file(&conn, file_2.id).await.unwrap().is_none()); assert!(get_file(&conn, file_3.id).await.unwrap().is_some()); assert!(get_file(&conn, file_4.id).await.unwrap().is_some()); } fn dummy_file(td: Duration) -> File { File { id: generate_file_id(), name: "foo".to_string(), expires_at: local_time().add(td), content_length: 100, } } async fn get_connection() -> Connection { let conn = Connection::open_in_memory().await.unwrap(); let init_db = tokio::fs::read_to_string("init-db.sql").await.unwrap(); let res = conn.call(move |conn| Ok(conn.execute(&init_db, []))).await; assert!(res.is_ok()); conn } }