From 5e2aee9248a00c8b213a8e07e4796d668bff519c Mon Sep 17 00:00:00 2001 From: Joris Date: Fri, 7 Feb 2025 09:17:26 +0100 Subject: Apply SQL migrations at startup --- src/db.rs | 205 --------------------------------- src/db/files.rs | 193 +++++++++++++++++++++++++++++++ src/db/migrations/01-init.sql | 7 ++ src/db/migrations/02-strict-tables.sql | 15 +++ src/db/mod.rs | 47 ++++++++ src/db/utils.rs | 3 + src/jobs.rs | 14 +-- src/main.rs | 8 +- src/routes.rs | 74 +++++++----- 9 files changed, 317 insertions(+), 249 deletions(-) delete mode 100644 src/db.rs create mode 100644 src/db/files.rs create mode 100644 src/db/migrations/01-init.sql create mode 100644 src/db/migrations/02-strict-tables.sql create mode 100644 src/db/mod.rs create mode 100644 src/db/utils.rs (limited to 'src') diff --git a/src/db.rs b/src/db.rs deleted file mode 100644 index ab699c6..0000000 --- a/src/db.rs +++ /dev/null @@ -1,205 +0,0 @@ -use chrono::{DateTime, Local}; -use tokio_rusqlite::{params, Connection, Result}; - -use crate::model::{decode_datetime, encode_datetime, File}; - -pub async fn insert_file(conn: &Connection, file: File) -> Result<()> { - conn.call(move |conn| { - conn.execute( - r#" - INSERT INTO - files(id, created_at, expires_at, filename, content_length) - VALUES - (?1, datetime(), ?2, ?3, ?4) - "#, - params![ - file.id, - encode_datetime(file.expires_at), - file.name, - file.content_length - ], - ) - .map_err(tokio_rusqlite::Error::Rusqlite) - }) - .await - .map(|_| ()) -} - -pub async fn get_file(conn: &Connection, file_id: String) -> Result> { - conn.call(move |conn| { - let mut stmt = conn.prepare( - r#" - SELECT - filename, expires_at, content_length - FROM - files - WHERE - id = ? - AND expires_at > datetime() - "#, - )?; - - let mut iter = stmt.query_map([file_id.clone()], |row| { - let res: (String, String, usize) = (row.get(0)?, row.get(1)?, row.get(2)?); - Ok(res) - })?; - - match iter.next() { - Some(Ok((filename, expires_at, content_length))) => { - match decode_datetime(&expires_at) { - Some(expires_at) => Ok(Some(File { - id: file_id.clone(), - name: filename, - expires_at, - content_length, - })), - _ => Err(rusqlite_other_error(&format!( - "Error decoding datetime: {expires_at}" - ))), - } - } - Some(_) => Err(rusqlite_other_error("Error reading file in DB")), - None => Ok(None), - } - }) - .await -} - -pub async fn list_expire_after(conn: &Connection, time: DateTime) -> Result> { - conn.call(move |conn| { - let mut stmt = conn.prepare( - r#" - SELECT - id - FROM - files - WHERE - expires_at > ? - "#, - )?; - - let iter = stmt.query_map([encode_datetime(time)], |row| row.get(0))?; - - let mut res = vec![]; - for id in iter { - res.push(id?) - } - Ok(res) - }) - .await -} - -pub async fn remove_expire_before(conn: &Connection, time: DateTime) -> Result<()> { - conn.call(move |conn| { - conn.execute( - &format!( - r#" - DELETE FROM - files - WHERE - expires_at <= ? - "# - ), - [encode_datetime(time)] - )?; - - Ok(()) - }) - .await -} - -fn rusqlite_other_error(msg: &str) -> tokio_rusqlite::Error { - tokio_rusqlite::Error::Other(msg.into()) -} - -#[cfg(test)] -mod tests { - use super::*; - use crate::model::{generate_file_id, local_time}; - use chrono::Duration; - use std::collections::HashSet; - use std::ops::Add; - - #[tokio::test] - async fn test_insert_and_get_file() { - let conn = get_connection().await; - let file = dummy_file(Duration::minutes(1)); - assert!(insert_file(&conn, file.clone()).await.is_ok()); - let file_res = get_file(&conn, file.id.clone()).await; - assert!(file_res.is_ok()); - assert_eq!(file_res.unwrap(), Some(file)); - } - - #[tokio::test] - async fn test_expired_file_err() { - let conn = get_connection().await; - let file = dummy_file(Duration::zero()); - assert!(insert_file(&conn, file.clone()).await.is_ok()); - let file_res = get_file(&conn, file.id.clone()).await; - assert!(file_res.is_ok()); - assert!(file_res.unwrap().is_none()); - } - - #[tokio::test] - async fn test_wrong_file_err() { - let conn = get_connection().await; - let file = dummy_file(Duration::minutes(1)); - assert!(insert_file(&conn, file.clone()).await.is_ok()); - let file_res = get_file(&conn, "wrong-id".to_string()).await; - assert!(file_res.is_ok()); - assert!(file_res.unwrap().is_none()); - } - - #[tokio::test] - async fn test_list_non_expirable() { - let conn = get_connection().await; - let file_expire = dummy_file(Duration::zero()); - let file_no_expire_1 = dummy_file(Duration::minutes(1)); - let file_no_expire_2 = dummy_file(Duration::minutes(1)); - assert!(insert_file(&conn, file_expire.clone()).await.is_ok()); - assert!(insert_file(&conn, file_no_expire_1.clone()).await.is_ok()); - assert!(insert_file(&conn, file_no_expire_2.clone()).await.is_ok()); - let list = list_expire_after(&conn, Local::now()).await; - assert!(list.is_ok()); - assert_eq!( - HashSet::from_iter(list.unwrap().iter()), - HashSet::from([&file_no_expire_1.id, &file_no_expire_2.id]) - ) - } - - #[tokio::test] - async fn test_remove_expire_before() { - let conn = get_connection().await; - let file_1 = dummy_file(Duration::zero()); - let file_2 = dummy_file(Duration::zero()); - let file_3 = dummy_file(Duration::minutes(1)); - let file_4 = dummy_file(Duration::minutes(1)); - assert!(insert_file(&conn, file_1.clone()).await.is_ok()); - assert!(insert_file(&conn, file_2.clone()).await.is_ok()); - assert!(insert_file(&conn, file_3.clone()).await.is_ok()); - assert!(insert_file(&conn, file_4.clone()).await.is_ok()); - assert!(remove_expire_before(&conn, Local::now()).await.is_ok()); - assert!(get_file(&conn, file_1.id).await.unwrap().is_none()); - assert!(get_file(&conn, file_2.id).await.unwrap().is_none()); - assert!(get_file(&conn, file_3.id).await.unwrap().is_some()); - assert!(get_file(&conn, file_4.id).await.unwrap().is_some()); - } - - fn dummy_file(td: Duration) -> File { - File { - id: generate_file_id(), - name: "foo".to_string(), - expires_at: local_time().add(td), - content_length: 100, - } - } - - async fn get_connection() -> Connection { - let conn = Connection::open_in_memory().await.unwrap(); - let init_db = tokio::fs::read_to_string("init-db.sql").await.unwrap(); - - let res = conn.call(move |conn| Ok(conn.execute(&init_db, []))).await; - assert!(res.is_ok()); - conn - } -} diff --git a/src/db/files.rs b/src/db/files.rs new file mode 100644 index 0000000..995a2ed --- /dev/null +++ b/src/db/files.rs @@ -0,0 +1,193 @@ +use chrono::{DateTime, Local}; +use tokio_rusqlite::{named_params, Connection, Result}; + +use crate::db::utils; +use crate::model::{decode_datetime, encode_datetime, File}; + +pub async fn insert(conn: &Connection, file: File) -> Result<()> { + let query = r#" + INSERT INTO files(id, created_at, expires_at, filename, content_length) + VALUES (:id, datetime(), :expires_at, :name, :content_length) + "#; + + conn.call(move |conn| { + conn.execute( + query, + named_params![ + ":id": file.id, + ":expires_at": encode_datetime(file.expires_at), + ":name": file.name, + ":content_length": file.content_length + ], + ) + .map_err(tokio_rusqlite::Error::Rusqlite) + }) + .await + .map(|_| ()) +} + +pub async fn get(conn: &Connection, file_id: impl Into) -> Result> { + let file_id = file_id.into(); + + let query = r#" + SELECT filename, expires_at, content_length + FROM files + WHERE + id = :id + AND expires_at > datetime() + "#; + + conn.call(move |conn| { + let mut stmt = conn.prepare(query)?; + + let mut iter = stmt.query_map(named_params![":id": file_id], |row| { + let res: (String, String, usize) = (row.get(0)?, row.get(1)?, row.get(2)?); + Ok(res) + })?; + + match iter.next() { + Some(Ok((filename, expires_at, content_length))) => { + match decode_datetime(&expires_at) { + Some(expires_at) => Ok(Some(File { + id: file_id, + name: filename, + expires_at, + content_length, + })), + _ => Err(utils::rusqlite_other_error(format!( + "Error decoding datetime: {expires_at}" + ))), + } + } + Some(_) => Err(utils::rusqlite_other_error("Error reading file in DB")), + None => Ok(None), + } + }) + .await +} + +pub async fn list_expire_after(conn: &Connection, time: DateTime) -> Result> { + let query = r#" + SELECT id + FROM files + WHERE expires_at > :expires_at + "#; + + conn.call(move |conn| { + let mut stmt = conn.prepare(query)?; + + let iter = stmt.query_map(named_params![":expires_at": encode_datetime(time)], |row| { + row.get(0) + })?; + + let mut res = vec![]; + for id in iter { + res.push(id?) + } + Ok(res) + }) + .await +} + +pub async fn remove_expire_before(conn: &Connection, time: DateTime) -> Result<()> { + let query = r#" + DELETE FROM files + WHERE expires_at <= :expires_at + "#; + + conn.call(move |conn| { + conn.execute(query, named_params![":expires_at": encode_datetime(time)])?; + Ok(()) + }) + .await +} + +#[cfg(test)] +mod tests { + use super::*; + use crate::model::{generate_file_id, local_time}; + use chrono::Duration; + use std::collections::HashSet; + use std::ops::Add; + + #[tokio::test] + async fn test_insert_and_get_file() { + let conn = get_connection().await; + let file = dummy_file(Duration::minutes(1)); + assert!(insert(&conn, file.clone()).await.is_ok()); + let file_res = get(&conn, file.id.clone()).await; + assert!(file_res.is_ok()); + assert_eq!(file_res.unwrap(), Some(file)); + } + + #[tokio::test] + async fn test_expired_file_err() { + let conn = get_connection().await; + let file = dummy_file(Duration::zero()); + assert!(insert(&conn, file.clone()).await.is_ok()); + let file_res = get(&conn, file.id.clone()).await; + assert!(file_res.is_ok()); + assert!(file_res.unwrap().is_none()); + } + + #[tokio::test] + async fn test_wrong_file_err() { + let conn = get_connection().await; + let file = dummy_file(Duration::minutes(1)); + assert!(insert(&conn, file.clone()).await.is_ok()); + let file_res = get(&conn, "wrong-id".to_string()).await; + assert!(file_res.is_ok()); + assert!(file_res.unwrap().is_none()); + } + + #[tokio::test] + async fn test_list_non_expirable() { + let conn = get_connection().await; + let file_expire = dummy_file(Duration::zero()); + let file_no_expire_1 = dummy_file(Duration::minutes(1)); + let file_no_expire_2 = dummy_file(Duration::minutes(1)); + assert!(insert(&conn, file_expire.clone()).await.is_ok()); + assert!(insert(&conn, file_no_expire_1.clone()).await.is_ok()); + assert!(insert(&conn, file_no_expire_2.clone()).await.is_ok()); + let list = list_expire_after(&conn, Local::now()).await; + assert!(list.is_ok()); + assert_eq!( + HashSet::from_iter(list.unwrap().iter()), + HashSet::from([&file_no_expire_1.id, &file_no_expire_2.id]) + ) + } + + #[tokio::test] + async fn test_remove_expire_before() { + let conn = get_connection().await; + let file_1 = dummy_file(Duration::zero()); + let file_2 = dummy_file(Duration::zero()); + let file_3 = dummy_file(Duration::minutes(1)); + let file_4 = dummy_file(Duration::minutes(1)); + assert!(insert(&conn, file_1.clone()).await.is_ok()); + assert!(insert(&conn, file_2.clone()).await.is_ok()); + assert!(insert(&conn, file_3.clone()).await.is_ok()); + assert!(insert(&conn, file_4.clone()).await.is_ok()); + assert!(remove_expire_before(&conn, Local::now()).await.is_ok()); + assert!(get(&conn, file_1.id).await.unwrap().is_none()); + assert!(get(&conn, file_2.id).await.unwrap().is_none()); + assert!(get(&conn, file_3.id).await.unwrap().is_some()); + assert!(get(&conn, file_4.id).await.unwrap().is_some()); + } + + fn dummy_file(td: Duration) -> File { + File { + id: generate_file_id(), + name: "foo".to_string(), + expires_at: local_time().add(td), + content_length: 100, + } + } + + async fn get_connection() -> Connection { + let conn = Connection::open_in_memory().await.unwrap(); + let res = crate::db::apply_migrations(&conn).await; + assert!(res.is_ok()); + conn + } +} diff --git a/src/db/migrations/01-init.sql b/src/db/migrations/01-init.sql new file mode 100644 index 0000000..75abc54 --- /dev/null +++ b/src/db/migrations/01-init.sql @@ -0,0 +1,7 @@ +CREATE TABLE IF NOT EXISTS "files" ( + id TEXT PRIMARY KEY, + created_at TEXT NOT NULL, + expires_at TEXT NOT NULL, + filename TEXT NOT NULL, + content_length INTEGER NOT NULL +); diff --git a/src/db/migrations/02-strict-tables.sql b/src/db/migrations/02-strict-tables.sql new file mode 100644 index 0000000..433ef39 --- /dev/null +++ b/src/db/migrations/02-strict-tables.sql @@ -0,0 +1,15 @@ +ALTER TABLE "files" RENAME TO "files_non_strict"; + +CREATE TABLE IF NOT EXISTS "files" ( + id TEXT PRIMARY KEY, + created_at TEXT NOT NULL, + expires_at TEXT NOT NULL, + filename TEXT NOT NULL, + content_length INTEGER NOT NULL +) STRICT; + +INSERT INTO files (id, created_at, expires_at, filename, content_length) + SELECT id, created_at, expires_at, filename, content_length + FROM files_non_strict; + +DROP TABLE files_non_strict; diff --git a/src/db/mod.rs b/src/db/mod.rs new file mode 100644 index 0000000..37769d2 --- /dev/null +++ b/src/db/mod.rs @@ -0,0 +1,47 @@ +use anyhow::{Error, Result}; +use rusqlite_migration::{Migrations, M}; +use tokio_rusqlite::Connection; + +pub mod files; +mod utils; + +pub async fn init(path: &str) -> Result { + let connection = Connection::open(path) + .await + .map_err(|err| Error::msg(format!("Error opening connection: {err}")))?; + + apply_migrations(&connection).await?; + set_pragma(&connection, "foreign_keys", "ON").await?; + set_pragma(&connection, "journal_mode", "wal").await?; + Ok(connection) +} + +async fn apply_migrations(conn: &Connection) -> Result<()> { + let migrations = Migrations::new(vec![ + M::up(include_str!("migrations/01-init.sql")), + M::up(include_str!("migrations/02-strict-tables.sql")), + ]); + + Ok(conn + .call(move |conn| { + migrations + .to_latest(conn) + .map_err(|migration_err| tokio_rusqlite::Error::Other(Box::new(migration_err))) + }) + .await?) +} + +async fn set_pragma( + conn: &Connection, + key: impl Into, + value: impl Into, +) -> Result<()> { + let key = key.into(); + let value = value.into(); + Ok(conn + .call(move |conn| { + conn.pragma_update(None, &key, &value) + .map_err(tokio_rusqlite::Error::Rusqlite) + }) + .await?) +} diff --git a/src/db/utils.rs b/src/db/utils.rs new file mode 100644 index 0000000..0b3e029 --- /dev/null +++ b/src/db/utils.rs @@ -0,0 +1,3 @@ +pub fn rusqlite_other_error(msg: impl Into) -> tokio_rusqlite::Error { + tokio_rusqlite::Error::Other(msg.into().into()) +} diff --git a/src/jobs.rs b/src/jobs.rs index a01a70d..cac7660 100644 --- a/src/jobs.rs +++ b/src/jobs.rs @@ -18,19 +18,19 @@ pub async fn start(db_conn: Connection, files_dir: String) { } } -async fn cleanup_expired(db_conn: &Connection, files_dir: &String) { +async fn cleanup_expired(db_conn: &Connection, files_dir: &str) { let time = Local::now(); match read_dir(files_dir).await { Err(msg) => log::error!("Listing files: {msg}"), - Ok(files) => match db::list_expire_after(db_conn, time).await { + Ok(files) => match db::files::list_expire_after(db_conn, time).await { Err(msg) => log::error!("Getting non expirable files: {msg}"), Ok(non_expirable) => { let non_expirable = HashSet::::from_iter(non_expirable.iter().cloned()); let expired_ids = files.difference(&non_expirable); let count = remove_files(files_dir, expired_ids.cloned()).await; log::info!("Removed {} files", count); - if let Err(msg) = db::remove_expire_before(db_conn, time).await { + if let Err(msg) = db::files::remove_expire_before(db_conn, time).await { log::error!("Removing files: {msg}") } } @@ -38,7 +38,7 @@ async fn cleanup_expired(db_conn: &Connection, files_dir: &String) { } } -async fn read_dir(files_dir: &String) -> Result, String> { +async fn read_dir(files_dir: &str) -> Result, String> { match fs::read_dir(files_dir).await { Err(msg) => Err(msg.to_string()), Ok(mut read_dir) => { @@ -61,16 +61,16 @@ async fn read_dir(files_dir: &String) -> Result, String> { } } -async fn remove_files(files_dir: &String, ids: I) -> i32 +async fn remove_files(files_dir: &str, ids: I) -> i32 where I: Iterator, { let mut count = 0; for id in ids { - let path = Path::new(&files_dir).join(id.clone()); + let path = Path::new(files_dir).join(id.clone()); match fs::remove_file(path).await { Err(msg) => log::error!("Removing file: {msg}"), - Ok(_) => count += 1 + Ok(_) => count += 1, } } count diff --git a/src/main.rs b/src/main.rs index b2af6de..0a7bd56 100644 --- a/src/main.rs +++ b/src/main.rs @@ -1,11 +1,11 @@ use std::env; use std::net::SocketAddr; +use anyhow::Result; use hyper::server::conn::http1; use hyper::service::service_fn; use hyper_util::rt::TokioIo; use tokio::net::TcpListener; -use tokio_rusqlite::Connection; mod db; mod jobs; @@ -15,7 +15,7 @@ mod templates; mod util; #[tokio::main] -async fn main() -> std::result::Result<(), Box> { +async fn main() -> Result<()> { env_logger::init(); let host = get_env("HOST"); @@ -24,9 +24,7 @@ async fn main() -> std::result::Result<(), Box> { let authorized_key = get_env("KEY"); let files_dir = get_env("FILES_DIR"); - let db_conn = Connection::open(db_path) - .await - .expect("Error while openning DB conection"); + let db_conn = db::init(&db_path).await?; tokio::spawn(jobs::start(db_conn.clone(), files_dir.clone())); diff --git a/src/routes.rs b/src/routes.rs index c48ec24..15c255e 100644 --- a/src/routes.rs +++ b/src/routes.rs @@ -27,22 +27,23 @@ pub async fn routes( let files_dir = Path::new(&files_dir); match (request.method(), path) { - (&Method::GET, [""]) => { - Ok(with_headers( - response(StatusCode::OK, templates::INDEX.to_string()), - vec![(CONTENT_TYPE, "text/html")], - )) - }, + (&Method::GET, [""]) => Ok(with_headers( + response(StatusCode::OK, templates::INDEX), + vec![(CONTENT_TYPE, "text/html")], + )), (&Method::GET, ["static", "main.js"]) => Ok(static_file( - include_str!("static/main.js").to_string(), + include_str!("static/main.js"), "application/javascript", )), - (&Method::GET, ["static", "main.css"]) => Ok(static_file( - include_str!("static/main.css").to_string(), - "text/css", - )), - (&Method::POST, ["upload"]) => upload_file(request, db_conn, authorized_key, files_dir).await, - (&Method::GET, ["share", file_id]) => get(db_conn, file_id, GetFile::ShowPage, files_dir).await, + (&Method::GET, ["static", "main.css"]) => { + Ok(static_file(include_str!("static/main.css"), "text/css")) + } + (&Method::POST, ["upload"]) => { + upload_file(request, db_conn, authorized_key, files_dir).await + } + (&Method::GET, ["share", file_id]) => { + get(db_conn, file_id, GetFile::ShowPage, files_dir).await + } (&Method::GET, ["share", file_id, "download"]) => { get(db_conn, file_id, GetFile::Download, files_dir).await } @@ -59,10 +60,7 @@ async fn upload_file( let key = get_header(&request, "X-Key"); if key != Some(authorized_key) { log::info!("Unauthorized file upload"); - Ok(response( - StatusCode::UNAUTHORIZED, - "Unauthorized".to_string(), - )) + Ok(response(StatusCode::UNAUTHORIZED, "Unauthorized")) } else { let file_id = model::generate_file_id(); let filename = get_header(&request, "X-Filename").map(|s| util::sanitize_filename(&s)); @@ -92,7 +90,7 @@ async fn upload_file( content_length, }; - match db::insert_file(&db_conn, file.clone()).await { + match db::files::insert(&db_conn, file.clone()).await { Ok(_) => Ok(response(StatusCode::OK, file_id)), Err(msg) => { log::error!("Insert file: {msg}"); @@ -128,14 +126,12 @@ async fn get( get_file: GetFile, files_dir: &Path, ) -> Result>> { - let file = db::get_file(&db_conn, file_id.to_string()).await; + let file = db::files::get(&db_conn, file_id).await; match (get_file, file) { - (GetFile::ShowPage, Ok(Some(file))) => { - Ok(with_headers( - response(StatusCode::OK, templates::file_page(file)), - vec![(CONTENT_TYPE, "text/html")], - )) - } + (GetFile::ShowPage, Ok(Some(file))) => Ok(with_headers( + response(StatusCode::OK, templates::file_page(file)), + vec![(CONTENT_TYPE, "text/html")], + )), (GetFile::Download, Ok(Some(file))) => { let path = files_dir.join(file_id); Ok(stream_file(path, file).await) @@ -148,17 +144,31 @@ async fn get( } } -fn static_file(text: String, content_type: &str) -> Response> { +fn static_file( + text: impl Into, + content_type: &str, +) -> Response> { let response = Response::builder() - .body(Full::new(text.into()).map_err(|e| match e {}).boxed()) + .body( + Full::new(text.into().into()) + .map_err(|e| match e {}) + .boxed(), + ) .unwrap(); with_headers(response, vec![(CONTENT_TYPE, content_type)]) } -fn response(status_code: StatusCode, text: String) -> Response> { +fn response( + status_code: StatusCode, + text: impl Into, +) -> Response> { Response::builder() .status(status_code) - .body(Full::new(text.into()).map_err(|e| match e {}).boxed()) + .body( + Full::new(text.into().into()) + .map_err(|e| match e {}) + .boxed(), + ) .unwrap() } @@ -190,17 +200,17 @@ async fn stream_file(path: PathBuf, file: model::File) -> Response Response> { - response(StatusCode::NOT_FOUND, templates::NOT_FOUND.to_string()) + response(StatusCode::NOT_FOUND, templates::NOT_FOUND) } fn bad_request() -> Response> { - response(StatusCode::BAD_REQUEST, templates::BAD_REQUEST.to_string()) + response(StatusCode::BAD_REQUEST, templates::BAD_REQUEST) } fn internal_server_error() -> Response> { response( StatusCode::INTERNAL_SERVER_ERROR, - templates::INTERNAL_SERVER_ERROR.to_string(), + templates::INTERNAL_SERVER_ERROR, ) } -- cgit v1.2.3