diff options
author | Joris | 2024-06-02 14:38:13 +0200 |
---|---|---|
committer | Joris | 2024-06-02 14:38:22 +0200 |
commit | 1019ea1ed341e3a7769c046aa0be5764789360b6 (patch) | |
tree | 1a0d8a4f00cff252d661c42fc23ed4c19795da6f /src/db.rs | |
parent | e8da9790dc6d55cd2e8883322cdf9a7bf5b4f5b7 (diff) |
Migrate to Rust and Hyper
With sanic, downloading a file locally is around ten times slower than
with Rust and hyper.
Maybe `pypy` could have helped, but I didn’t succeed to set it up
quickly with the dependencies.
Diffstat (limited to 'src/db.rs')
-rw-r--r-- | src/db.rs | 125 |
1 files changed, 125 insertions, 0 deletions
diff --git a/src/db.rs b/src/db.rs new file mode 100644 index 0000000..e1bb7e3 --- /dev/null +++ b/src/db.rs @@ -0,0 +1,125 @@ +use tokio_rusqlite::{params, Connection, Result}; + +use crate::model::{decode_datetime, encode_datetime, File}; + +pub async fn insert_file(conn: &Connection, file: File) -> Result<()> { + conn.call(move |conn| { + conn.execute( + r#" + INSERT INTO + files(id, created_at, expires_at, filename, content_length) + VALUES + (?1, datetime(), ?2, ?3, ?4) + "#, + params![ + file.id, + encode_datetime(file.expires_at), + file.name, + file.content_length + ], + ) + .map_err(tokio_rusqlite::Error::Rusqlite) + }) + .await + .map(|_| ()) +} + +pub async fn get_file(conn: &Connection, file_id: String) -> Result<Option<File>> { + conn.call(move |conn| { + let mut stmt = conn.prepare( + r#" + SELECT + filename, expires_at, content_length + FROM + files + WHERE + id = ? + AND expires_at > datetime() + "#, + )?; + + let mut iter = stmt.query_map([file_id.clone()], |row| { + let res: (String, String, usize) = (row.get(0)?, row.get(1)?, row.get(2)?); + Ok(res) + })?; + + match iter.next() { + Some(Ok((filename, expires_at, content_length))) => { + match decode_datetime(&expires_at) { + Some(expires_at) => Ok(Some(File { + id: file_id.clone(), + name: filename, + expires_at, + content_length, + })), + _ => Err(rusqlite_other_error(&format!( + "Error decoding datetime: {expires_at}" + ))), + } + } + Some(_) => Err(rusqlite_other_error("Error reading file in DB")), + None => Ok(None), + } + }) + .await +} + +fn rusqlite_other_error(msg: &str) -> tokio_rusqlite::Error { + tokio_rusqlite::Error::Other(msg.into()) +} + +#[cfg(test)] +mod tests { + use super::*; + use crate::model::local_time; + use chrono::Duration; + use std::ops::Add; + + #[tokio::test] + async fn test_insert_and_get_file() { + let conn = get_connection().await; + let file = dummy_file(Duration::minutes(1)); + assert!(insert_file(&conn, file.clone()).await.is_ok()); + let file_res = get_file(&conn, file.id.clone()).await; + assert!(file_res.is_ok()); + assert_eq!(file_res.unwrap(), Some(file)); + } + + #[tokio::test] + async fn test_expired_file_err() { + let conn = get_connection().await; + let file = dummy_file(Duration::zero()); + assert!(insert_file(&conn, file.clone()).await.is_ok()); + let file_res = get_file(&conn, file.id.clone()).await; + assert!(file_res.is_ok()); + assert!(file_res.unwrap().is_none()); + } + + #[tokio::test] + async fn test_wrong_file_err() { + let conn = get_connection().await; + let file = dummy_file(Duration::minutes(1)); + assert!(insert_file(&conn, file.clone()).await.is_ok()); + let file_res = get_file(&conn, "wrong-id".to_string()).await; + assert!(file_res.is_ok()); + assert!(file_res.unwrap().is_none()); + } + + fn dummy_file(td: Duration) -> File { + File { + id: "1234".to_string(), + name: "foo".to_string(), + expires_at: local_time().add(td), + content_length: 100, + } + } + + async fn get_connection() -> Connection { + let conn = Connection::open_in_memory().await.unwrap(); + let init_db = tokio::fs::read_to_string("init-db.sql").await.unwrap(); + + let res = conn.call(move |conn| Ok(conn.execute(&init_db, []))).await; + assert!(res.is_ok()); + conn + } +} |