diff options
author | Joris | 2025-02-07 09:17:26 +0100 |
---|---|---|
committer | Joris | 2025-02-07 09:17:26 +0100 |
commit | 5e2aee9248a00c8b213a8e07e4796d668bff519c (patch) | |
tree | 861fdc7f2b23cc8538cf479f44f62e700a77222f /src | |
parent | 463d58b37909c976d3f30bdb5a652f0e8a018b55 (diff) |
Apply SQL migrations at startup
Diffstat (limited to 'src')
-rw-r--r-- | src/db/files.rs (renamed from src/db.rs) | 144 | ||||
-rw-r--r-- | src/db/migrations/01-init.sql | 7 | ||||
-rw-r--r-- | src/db/migrations/02-strict-tables.sql | 15 | ||||
-rw-r--r-- | src/db/mod.rs | 47 | ||||
-rw-r--r-- | src/db/utils.rs | 3 | ||||
-rw-r--r-- | src/jobs.rs | 14 | ||||
-rw-r--r-- | src/main.rs | 8 | ||||
-rw-r--r-- | src/routes.rs | 74 |
8 files changed, 190 insertions, 122 deletions
diff --git a/src/db.rs b/src/db/files.rs index ab699c6..995a2ed 100644 --- a/src/db.rs +++ b/src/db/files.rs @@ -1,22 +1,23 @@ use chrono::{DateTime, Local}; -use tokio_rusqlite::{params, Connection, Result}; +use tokio_rusqlite::{named_params, Connection, Result}; +use crate::db::utils; use crate::model::{decode_datetime, encode_datetime, File}; -pub async fn insert_file(conn: &Connection, file: File) -> Result<()> { +pub async fn insert(conn: &Connection, file: File) -> Result<()> { + let query = r#" + INSERT INTO files(id, created_at, expires_at, filename, content_length) + VALUES (:id, datetime(), :expires_at, :name, :content_length) + "#; + conn.call(move |conn| { conn.execute( - r#" - INSERT INTO - files(id, created_at, expires_at, filename, content_length) - VALUES - (?1, datetime(), ?2, ?3, ?4) - "#, - params![ - file.id, - encode_datetime(file.expires_at), - file.name, - file.content_length + query, + named_params![ + ":id": file.id, + ":expires_at": encode_datetime(file.expires_at), + ":name": file.name, + ":content_length": file.content_length ], ) .map_err(tokio_rusqlite::Error::Rusqlite) @@ -25,21 +26,21 @@ pub async fn insert_file(conn: &Connection, file: File) -> Result<()> { .map(|_| ()) } -pub async fn get_file(conn: &Connection, file_id: String) -> Result<Option<File>> { +pub async fn get(conn: &Connection, file_id: impl Into<String>) -> Result<Option<File>> { + let file_id = file_id.into(); + + let query = r#" + SELECT filename, expires_at, content_length + FROM files + WHERE + id = :id + AND expires_at > datetime() + "#; + conn.call(move |conn| { - let mut stmt = conn.prepare( - r#" - SELECT - filename, expires_at, content_length - FROM - files - WHERE - id = ? - AND expires_at > datetime() - "#, - )?; - - let mut iter = stmt.query_map([file_id.clone()], |row| { + let mut stmt = conn.prepare(query)?; + + let mut iter = stmt.query_map(named_params![":id": file_id], |row| { let res: (String, String, usize) = (row.get(0)?, row.get(1)?, row.get(2)?); Ok(res) })?; @@ -48,17 +49,17 @@ pub async fn get_file(conn: &Connection, file_id: String) -> Result<Option<File> Some(Ok((filename, expires_at, content_length))) => { match decode_datetime(&expires_at) { Some(expires_at) => Ok(Some(File { - id: file_id.clone(), + id: file_id, name: filename, expires_at, content_length, })), - _ => Err(rusqlite_other_error(&format!( + _ => Err(utils::rusqlite_other_error(format!( "Error decoding datetime: {expires_at}" ))), } } - Some(_) => Err(rusqlite_other_error("Error reading file in DB")), + Some(_) => Err(utils::rusqlite_other_error("Error reading file in DB")), None => Ok(None), } }) @@ -66,19 +67,18 @@ pub async fn get_file(conn: &Connection, file_id: String) -> Result<Option<File> } pub async fn list_expire_after(conn: &Connection, time: DateTime<Local>) -> Result<Vec<String>> { + let query = r#" + SELECT id + FROM files + WHERE expires_at > :expires_at + "#; + conn.call(move |conn| { - let mut stmt = conn.prepare( - r#" - SELECT - id - FROM - files - WHERE - expires_at > ? - "#, - )?; - - let iter = stmt.query_map([encode_datetime(time)], |row| row.get(0))?; + let mut stmt = conn.prepare(query)?; + + let iter = stmt.query_map(named_params![":expires_at": encode_datetime(time)], |row| { + row.get(0) + })?; let mut res = vec![]; for id in iter { @@ -90,28 +90,18 @@ pub async fn list_expire_after(conn: &Connection, time: DateTime<Local>) -> Resu } pub async fn remove_expire_before(conn: &Connection, time: DateTime<Local>) -> Result<()> { - conn.call(move |conn| { - conn.execute( - &format!( - r#" - DELETE FROM - files - WHERE - expires_at <= ? - "# - ), - [encode_datetime(time)] - )?; + let query = r#" + DELETE FROM files + WHERE expires_at <= :expires_at + "#; + conn.call(move |conn| { + conn.execute(query, named_params![":expires_at": encode_datetime(time)])?; Ok(()) }) .await } -fn rusqlite_other_error(msg: &str) -> tokio_rusqlite::Error { - tokio_rusqlite::Error::Other(msg.into()) -} - #[cfg(test)] mod tests { use super::*; @@ -124,8 +114,8 @@ mod tests { async fn test_insert_and_get_file() { let conn = get_connection().await; let file = dummy_file(Duration::minutes(1)); - assert!(insert_file(&conn, file.clone()).await.is_ok()); - let file_res = get_file(&conn, file.id.clone()).await; + assert!(insert(&conn, file.clone()).await.is_ok()); + let file_res = get(&conn, file.id.clone()).await; assert!(file_res.is_ok()); assert_eq!(file_res.unwrap(), Some(file)); } @@ -134,8 +124,8 @@ mod tests { async fn test_expired_file_err() { let conn = get_connection().await; let file = dummy_file(Duration::zero()); - assert!(insert_file(&conn, file.clone()).await.is_ok()); - let file_res = get_file(&conn, file.id.clone()).await; + assert!(insert(&conn, file.clone()).await.is_ok()); + let file_res = get(&conn, file.id.clone()).await; assert!(file_res.is_ok()); assert!(file_res.unwrap().is_none()); } @@ -144,8 +134,8 @@ mod tests { async fn test_wrong_file_err() { let conn = get_connection().await; let file = dummy_file(Duration::minutes(1)); - assert!(insert_file(&conn, file.clone()).await.is_ok()); - let file_res = get_file(&conn, "wrong-id".to_string()).await; + assert!(insert(&conn, file.clone()).await.is_ok()); + let file_res = get(&conn, "wrong-id".to_string()).await; assert!(file_res.is_ok()); assert!(file_res.unwrap().is_none()); } @@ -156,9 +146,9 @@ mod tests { let file_expire = dummy_file(Duration::zero()); let file_no_expire_1 = dummy_file(Duration::minutes(1)); let file_no_expire_2 = dummy_file(Duration::minutes(1)); - assert!(insert_file(&conn, file_expire.clone()).await.is_ok()); - assert!(insert_file(&conn, file_no_expire_1.clone()).await.is_ok()); - assert!(insert_file(&conn, file_no_expire_2.clone()).await.is_ok()); + assert!(insert(&conn, file_expire.clone()).await.is_ok()); + assert!(insert(&conn, file_no_expire_1.clone()).await.is_ok()); + assert!(insert(&conn, file_no_expire_2.clone()).await.is_ok()); let list = list_expire_after(&conn, Local::now()).await; assert!(list.is_ok()); assert_eq!( @@ -174,15 +164,15 @@ mod tests { let file_2 = dummy_file(Duration::zero()); let file_3 = dummy_file(Duration::minutes(1)); let file_4 = dummy_file(Duration::minutes(1)); - assert!(insert_file(&conn, file_1.clone()).await.is_ok()); - assert!(insert_file(&conn, file_2.clone()).await.is_ok()); - assert!(insert_file(&conn, file_3.clone()).await.is_ok()); - assert!(insert_file(&conn, file_4.clone()).await.is_ok()); + assert!(insert(&conn, file_1.clone()).await.is_ok()); + assert!(insert(&conn, file_2.clone()).await.is_ok()); + assert!(insert(&conn, file_3.clone()).await.is_ok()); + assert!(insert(&conn, file_4.clone()).await.is_ok()); assert!(remove_expire_before(&conn, Local::now()).await.is_ok()); - assert!(get_file(&conn, file_1.id).await.unwrap().is_none()); - assert!(get_file(&conn, file_2.id).await.unwrap().is_none()); - assert!(get_file(&conn, file_3.id).await.unwrap().is_some()); - assert!(get_file(&conn, file_4.id).await.unwrap().is_some()); + assert!(get(&conn, file_1.id).await.unwrap().is_none()); + assert!(get(&conn, file_2.id).await.unwrap().is_none()); + assert!(get(&conn, file_3.id).await.unwrap().is_some()); + assert!(get(&conn, file_4.id).await.unwrap().is_some()); } fn dummy_file(td: Duration) -> File { @@ -196,9 +186,7 @@ mod tests { async fn get_connection() -> Connection { let conn = Connection::open_in_memory().await.unwrap(); - let init_db = tokio::fs::read_to_string("init-db.sql").await.unwrap(); - - let res = conn.call(move |conn| Ok(conn.execute(&init_db, []))).await; + let res = crate::db::apply_migrations(&conn).await; assert!(res.is_ok()); conn } diff --git a/src/db/migrations/01-init.sql b/src/db/migrations/01-init.sql new file mode 100644 index 0000000..75abc54 --- /dev/null +++ b/src/db/migrations/01-init.sql @@ -0,0 +1,7 @@ +CREATE TABLE IF NOT EXISTS "files" ( + id TEXT PRIMARY KEY, + created_at TEXT NOT NULL, + expires_at TEXT NOT NULL, + filename TEXT NOT NULL, + content_length INTEGER NOT NULL +); diff --git a/src/db/migrations/02-strict-tables.sql b/src/db/migrations/02-strict-tables.sql new file mode 100644 index 0000000..433ef39 --- /dev/null +++ b/src/db/migrations/02-strict-tables.sql @@ -0,0 +1,15 @@ +ALTER TABLE "files" RENAME TO "files_non_strict"; + +CREATE TABLE IF NOT EXISTS "files" ( + id TEXT PRIMARY KEY, + created_at TEXT NOT NULL, + expires_at TEXT NOT NULL, + filename TEXT NOT NULL, + content_length INTEGER NOT NULL +) STRICT; + +INSERT INTO files (id, created_at, expires_at, filename, content_length) + SELECT id, created_at, expires_at, filename, content_length + FROM files_non_strict; + +DROP TABLE files_non_strict; diff --git a/src/db/mod.rs b/src/db/mod.rs new file mode 100644 index 0000000..37769d2 --- /dev/null +++ b/src/db/mod.rs @@ -0,0 +1,47 @@ +use anyhow::{Error, Result}; +use rusqlite_migration::{Migrations, M}; +use tokio_rusqlite::Connection; + +pub mod files; +mod utils; + +pub async fn init(path: &str) -> Result<Connection> { + let connection = Connection::open(path) + .await + .map_err(|err| Error::msg(format!("Error opening connection: {err}")))?; + + apply_migrations(&connection).await?; + set_pragma(&connection, "foreign_keys", "ON").await?; + set_pragma(&connection, "journal_mode", "wal").await?; + Ok(connection) +} + +async fn apply_migrations(conn: &Connection) -> Result<()> { + let migrations = Migrations::new(vec![ + M::up(include_str!("migrations/01-init.sql")), + M::up(include_str!("migrations/02-strict-tables.sql")), + ]); + + Ok(conn + .call(move |conn| { + migrations + .to_latest(conn) + .map_err(|migration_err| tokio_rusqlite::Error::Other(Box::new(migration_err))) + }) + .await?) +} + +async fn set_pragma( + conn: &Connection, + key: impl Into<String>, + value: impl Into<String>, +) -> Result<()> { + let key = key.into(); + let value = value.into(); + Ok(conn + .call(move |conn| { + conn.pragma_update(None, &key, &value) + .map_err(tokio_rusqlite::Error::Rusqlite) + }) + .await?) +} diff --git a/src/db/utils.rs b/src/db/utils.rs new file mode 100644 index 0000000..0b3e029 --- /dev/null +++ b/src/db/utils.rs @@ -0,0 +1,3 @@ +pub fn rusqlite_other_error(msg: impl Into<String>) -> tokio_rusqlite::Error { + tokio_rusqlite::Error::Other(msg.into().into()) +} diff --git a/src/jobs.rs b/src/jobs.rs index a01a70d..cac7660 100644 --- a/src/jobs.rs +++ b/src/jobs.rs @@ -18,19 +18,19 @@ pub async fn start(db_conn: Connection, files_dir: String) { } } -async fn cleanup_expired(db_conn: &Connection, files_dir: &String) { +async fn cleanup_expired(db_conn: &Connection, files_dir: &str) { let time = Local::now(); match read_dir(files_dir).await { Err(msg) => log::error!("Listing files: {msg}"), - Ok(files) => match db::list_expire_after(db_conn, time).await { + Ok(files) => match db::files::list_expire_after(db_conn, time).await { Err(msg) => log::error!("Getting non expirable files: {msg}"), Ok(non_expirable) => { let non_expirable = HashSet::<String>::from_iter(non_expirable.iter().cloned()); let expired_ids = files.difference(&non_expirable); let count = remove_files(files_dir, expired_ids.cloned()).await; log::info!("Removed {} files", count); - if let Err(msg) = db::remove_expire_before(db_conn, time).await { + if let Err(msg) = db::files::remove_expire_before(db_conn, time).await { log::error!("Removing files: {msg}") } } @@ -38,7 +38,7 @@ async fn cleanup_expired(db_conn: &Connection, files_dir: &String) { } } -async fn read_dir(files_dir: &String) -> Result<HashSet<String>, String> { +async fn read_dir(files_dir: &str) -> Result<HashSet<String>, String> { match fs::read_dir(files_dir).await { Err(msg) => Err(msg.to_string()), Ok(mut read_dir) => { @@ -61,16 +61,16 @@ async fn read_dir(files_dir: &String) -> Result<HashSet<String>, String> { } } -async fn remove_files<I>(files_dir: &String, ids: I) -> i32 +async fn remove_files<I>(files_dir: &str, ids: I) -> i32 where I: Iterator<Item = String>, { let mut count = 0; for id in ids { - let path = Path::new(&files_dir).join(id.clone()); + let path = Path::new(files_dir).join(id.clone()); match fs::remove_file(path).await { Err(msg) => log::error!("Removing file: {msg}"), - Ok(_) => count += 1 + Ok(_) => count += 1, } } count diff --git a/src/main.rs b/src/main.rs index b2af6de..0a7bd56 100644 --- a/src/main.rs +++ b/src/main.rs @@ -1,11 +1,11 @@ use std::env; use std::net::SocketAddr; +use anyhow::Result; use hyper::server::conn::http1; use hyper::service::service_fn; use hyper_util::rt::TokioIo; use tokio::net::TcpListener; -use tokio_rusqlite::Connection; mod db; mod jobs; @@ -15,7 +15,7 @@ mod templates; mod util; #[tokio::main] -async fn main() -> std::result::Result<(), Box<dyn std::error::Error>> { +async fn main() -> Result<()> { env_logger::init(); let host = get_env("HOST"); @@ -24,9 +24,7 @@ async fn main() -> std::result::Result<(), Box<dyn std::error::Error>> { let authorized_key = get_env("KEY"); let files_dir = get_env("FILES_DIR"); - let db_conn = Connection::open(db_path) - .await - .expect("Error while openning DB conection"); + let db_conn = db::init(&db_path).await?; tokio::spawn(jobs::start(db_conn.clone(), files_dir.clone())); diff --git a/src/routes.rs b/src/routes.rs index c48ec24..15c255e 100644 --- a/src/routes.rs +++ b/src/routes.rs @@ -27,22 +27,23 @@ pub async fn routes( let files_dir = Path::new(&files_dir); match (request.method(), path) { - (&Method::GET, [""]) => { - Ok(with_headers( - response(StatusCode::OK, templates::INDEX.to_string()), - vec![(CONTENT_TYPE, "text/html")], - )) - }, + (&Method::GET, [""]) => Ok(with_headers( + response(StatusCode::OK, templates::INDEX), + vec![(CONTENT_TYPE, "text/html")], + )), (&Method::GET, ["static", "main.js"]) => Ok(static_file( - include_str!("static/main.js").to_string(), + include_str!("static/main.js"), "application/javascript", )), - (&Method::GET, ["static", "main.css"]) => Ok(static_file( - include_str!("static/main.css").to_string(), - "text/css", - )), - (&Method::POST, ["upload"]) => upload_file(request, db_conn, authorized_key, files_dir).await, - (&Method::GET, ["share", file_id]) => get(db_conn, file_id, GetFile::ShowPage, files_dir).await, + (&Method::GET, ["static", "main.css"]) => { + Ok(static_file(include_str!("static/main.css"), "text/css")) + } + (&Method::POST, ["upload"]) => { + upload_file(request, db_conn, authorized_key, files_dir).await + } + (&Method::GET, ["share", file_id]) => { + get(db_conn, file_id, GetFile::ShowPage, files_dir).await + } (&Method::GET, ["share", file_id, "download"]) => { get(db_conn, file_id, GetFile::Download, files_dir).await } @@ -59,10 +60,7 @@ async fn upload_file( let key = get_header(&request, "X-Key"); if key != Some(authorized_key) { log::info!("Unauthorized file upload"); - Ok(response( - StatusCode::UNAUTHORIZED, - "Unauthorized".to_string(), - )) + Ok(response(StatusCode::UNAUTHORIZED, "Unauthorized")) } else { let file_id = model::generate_file_id(); let filename = get_header(&request, "X-Filename").map(|s| util::sanitize_filename(&s)); @@ -92,7 +90,7 @@ async fn upload_file( content_length, }; - match db::insert_file(&db_conn, file.clone()).await { + match db::files::insert(&db_conn, file.clone()).await { Ok(_) => Ok(response(StatusCode::OK, file_id)), Err(msg) => { log::error!("Insert file: {msg}"); @@ -128,14 +126,12 @@ async fn get( get_file: GetFile, files_dir: &Path, ) -> Result<Response<BoxBody<Bytes, std::io::Error>>> { - let file = db::get_file(&db_conn, file_id.to_string()).await; + let file = db::files::get(&db_conn, file_id).await; match (get_file, file) { - (GetFile::ShowPage, Ok(Some(file))) => { - Ok(with_headers( - response(StatusCode::OK, templates::file_page(file)), - vec![(CONTENT_TYPE, "text/html")], - )) - } + (GetFile::ShowPage, Ok(Some(file))) => Ok(with_headers( + response(StatusCode::OK, templates::file_page(file)), + vec![(CONTENT_TYPE, "text/html")], + )), (GetFile::Download, Ok(Some(file))) => { let path = files_dir.join(file_id); Ok(stream_file(path, file).await) @@ -148,17 +144,31 @@ async fn get( } } -fn static_file(text: String, content_type: &str) -> Response<BoxBody<Bytes, std::io::Error>> { +fn static_file( + text: impl Into<String>, + content_type: &str, +) -> Response<BoxBody<Bytes, std::io::Error>> { let response = Response::builder() - .body(Full::new(text.into()).map_err(|e| match e {}).boxed()) + .body( + Full::new(text.into().into()) + .map_err(|e| match e {}) + .boxed(), + ) .unwrap(); with_headers(response, vec![(CONTENT_TYPE, content_type)]) } -fn response(status_code: StatusCode, text: String) -> Response<BoxBody<Bytes, std::io::Error>> { +fn response( + status_code: StatusCode, + text: impl Into<String>, +) -> Response<BoxBody<Bytes, std::io::Error>> { Response::builder() .status(status_code) - .body(Full::new(text.into()).map_err(|e| match e {}).boxed()) + .body( + Full::new(text.into().into()) + .map_err(|e| match e {}) + .boxed(), + ) .unwrap() } @@ -190,17 +200,17 @@ async fn stream_file(path: PathBuf, file: model::File) -> Response<BoxBody<Bytes } fn not_found() -> Response<BoxBody<Bytes, std::io::Error>> { - response(StatusCode::NOT_FOUND, templates::NOT_FOUND.to_string()) + response(StatusCode::NOT_FOUND, templates::NOT_FOUND) } fn bad_request() -> Response<BoxBody<Bytes, std::io::Error>> { - response(StatusCode::BAD_REQUEST, templates::BAD_REQUEST.to_string()) + response(StatusCode::BAD_REQUEST, templates::BAD_REQUEST) } fn internal_server_error() -> Response<BoxBody<Bytes, std::io::Error>> { response( StatusCode::INTERNAL_SERVER_ERROR, - templates::INTERNAL_SERVER_ERROR.to_string(), + templates::INTERNAL_SERVER_ERROR, ) } |