diff --git a/Cargo.lock b/Cargo.lock index cb16dc2..98c5677 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -523,7 +523,7 @@ dependencies = [ [[package]] name = "datatrash" -version = "2.5.0" +version = "2.5.1" dependencies = [ "actix-files", "actix-governor", diff --git a/Cargo.toml b/Cargo.toml index 5782308..66b8b95 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -1,6 +1,6 @@ [package] name = "datatrash" -version = "2.5.0" +version = "2.5.1" authors = ["neri"] edition = "2021" diff --git a/src/download.rs b/src/download.rs index f020c64..6a02be2 100644 --- a/src/download.rs +++ b/src/download.rs @@ -1,4 +1,4 @@ -use std::{path::PathBuf, str::FromStr, time::SystemTime}; +use std::{io::ErrorKind, time::SystemTime}; use actix_files::NamedFile; use actix_web::{ @@ -16,7 +16,12 @@ use std::path::Path; use time::OffsetDateTime; use tokio::fs; -use crate::{config::Config, deleter, mime_relations, template}; +use crate::{ + config::Config, + deleter, + file_info::{self, FileInfo}, + mime_relations, template, +}; const TEXT_VIEW_SIZE_LIMIT: u64 = 512 * 1024; // 512KiB @@ -32,24 +37,27 @@ pub async fn download( config: web::Data, ) -> Result { let id = req.match_info().query("id"); - let (file_id, file_name, valid_till, content_type, delete) = load_file_info(id, &db).await?; - let path = config.files_dir.join(&file_id); + let file_info = file_info::find_by_id(id, &db) + .await + .map_err(|db_err| { + log::error!("could not run select statement {:?}", db_err); + error::ErrorInternalServerError("could not run select statement") + })? + .ok_or_else(|| error::ErrorNotFound("file does not exist or has expired"))?; + let delete = file_info.delete_on_download; + let valid_till = file_info.valid_till; + let path = config.files_dir.join(&file_info.file_id); - let mime = Mime::from_str(&content_type).unwrap_or(APPLICATION_OCTET_STREAM); - let computed_file_name = file_name.clone().unwrap_or_else(|| { - let extension = mime_relations::get_extension(&mime).unwrap_or("txt"); - format!("{file_id}.{extension}") - }); - let mut response = match get_view_type(&req, &mime, &path, delete).await { - ViewType::Raw => build_file_response(false, &computed_file_name, path, mime, &req), - ViewType::Download => build_file_response(true, &computed_file_name, path, mime, &req), - ViewType::Html => build_html_response(file_name.as_deref(), &path, &config, &req).await, + let mut response = match get_view_type(&file_info, &path, &req).await { + ViewType::Raw => build_file_response(file_info, &path, false, &req), + ViewType::Download => build_file_response(file_info, &path, true, &req), + ViewType::Html => build_html_response(file_info, &path, &config, &req).await, }?; insert_cache_headers(&mut response, valid_till); if delete { - deleter::delete_by_id(&db, &file_id, &config.files_dir) + deleter::delete_by_id(&db, id, &config.files_dir) .await .map_err(|db_err| { log::error!("could not delete file {:?}", db_err); @@ -60,36 +68,14 @@ pub async fn download( Ok(response) } -async fn load_file_info( - id: &str, - db: &web::Data>, -) -> Result<(String, Option, OffsetDateTime, String, bool), Error> { - sqlx::query_as( - "SELECT file_id, file_name, valid_till, content_type, delete_on_download from files WHERE file_id = $1", - ) - .bind(id) - .fetch_optional(db.as_ref()) - .await - .map_err(|db_err| { - log::error!("could not run select statement {:?}", db_err); - error::ErrorInternalServerError("could not run select statement") - })? - .ok_or_else(|| error::ErrorNotFound("file does not exist or has expired")) -} - -async fn get_view_type( - req: &HttpRequest, - mime: &Mime, - file_path: &Path, - delete_on_download: bool, -) -> ViewType { - if delete_on_download || req.query_string().contains("dl") { +async fn get_view_type(file_info: &FileInfo, file_path: &Path, req: &HttpRequest) -> ViewType { + if file_info.delete_on_download || req.query_string().contains("dl") { return ViewType::Download; } if req.query_string().contains("raw") { return ViewType::Raw; } - if !mime_relations::matches_text(mime) { + if !mime_relations::matches_text(&file_info.content_type) { return ViewType::Raw; } if get_file_size(file_path).await >= TEXT_VIEW_SIZE_LIMIT { @@ -100,7 +86,7 @@ async fn get_view_type( if accept_mime == TEXT_HTML { return ViewType::Html; } - if mime_matches(&accept_mime, mime) { + if mime_matches(&accept_mime, &file_info.content_type) { break; } } @@ -122,28 +108,42 @@ async fn get_file_size(file_path: &Path) -> u64 { } async fn build_html_response( - file_name: Option<&str>, + file_info: FileInfo, path: &Path, config: &Config, req: &HttpRequest, ) -> Result { - let content = fs::read_to_string(path).await.map_err(|file_err| { - log::error!("file could not be read {:?}", file_err); - error::ErrorInternalServerError("this file should be here but could not be found") - })?; - let html_view = template::build_html_view_template(&content, file_name, req, config); + let content = match fs::read_to_string(path).await { + Ok(content) => content, + Err(file_err) if file_err.kind() == ErrorKind::InvalidData => { + // content may not be valid UTF-8, try to return the raw file instead + return build_file_response(file_info, path, false, req); + } + Err(file_err) => { + log::error!("file could not be read: {:?}", file_err); + return Err(error::ErrorInternalServerError( + "this file should be here but could not be found", + )); + } + }; + let html_view = + template::build_html_view_template(&content, file_info.file_name.as_deref(), req, config); Ok(HttpResponse::Ok() .content_type(TEXT_HTML.to_string()) .body(html_view)) } fn build_file_response( + file_info: FileInfo, + path: &Path, download: bool, - file_name: &str, - path: PathBuf, - mime: Mime, req: &HttpRequest, ) -> Result { + let file_name: String = file_info.file_name.unwrap_or_else(|| { + let file_id = file_info.file_id; + let extension = mime_relations::get_extension(&file_info.content_type).unwrap_or("txt"); + format!("{file_id}.{extension}") + }); let content_disposition = ContentDisposition { disposition: if download { DispositionType::Attachment @@ -154,10 +154,10 @@ fn build_file_response( }; let file = NamedFile::open(path) .map_err(|file_err| { - log::error!("file could not be read {:?}", file_err); + log::error!("file could not be read: {:?}", file_err); error::ErrorInternalServerError("this file should be here but could not be found") })? - .set_content_type(mime) + .set_content_type(file_info.content_type) .set_content_disposition(content_disposition); let mut response = file.into_response(req); @@ -165,16 +165,19 @@ fn build_file_response( Ok(response) } -fn get_disposition_params(filename: &str) -> Vec { - let mut parameters = vec![DispositionParam::Filename(filename.to_owned())]; +fn get_disposition_params(filename: String) -> Vec { if !filename.is_ascii() { - parameters.push(DispositionParam::FilenameExt(ExtendedValue { - charset: Charset::Ext(String::from("UTF-8")), - language_tag: None, - value: filename.to_owned().into_bytes(), - })); + vec![ + DispositionParam::Filename(filename.clone()), + DispositionParam::FilenameExt(ExtendedValue { + charset: Charset::Ext(String::from("UTF-8")), + language_tag: None, + value: filename.into_bytes(), + }), + ] + } else { + vec![DispositionParam::Filename(filename)] } - parameters } const ALLOWED_CONTEXTS: [&str; 6] = ["audio", "document", "empty", "font", "image", "video"]; @@ -209,7 +212,7 @@ fn insert_cache_headers(response: &mut HttpResponse, valid_till: OffsetDateTime) CacheControl(vec![ CacheDirective::Public, CacheDirective::MustRevalidate, - CacheDirective::MaxAge(valid_cache_seconds), // todo: expiry in seconds + CacheDirective::MaxAge(valid_cache_seconds), CacheDirective::NoTransform, CacheDirective::Extension("immutable".to_owned(), None), ]) diff --git a/src/file_info.rs b/src/file_info.rs new file mode 100644 index 0000000..35ad601 --- /dev/null +++ b/src/file_info.rs @@ -0,0 +1,69 @@ +use mime::{Mime, APPLICATION_OCTET_STREAM}; +use sqlx::{postgres::PgRow, FromRow, Row}; +use time::OffsetDateTime; + +use crate::multipart::UploadConfig; + +pub struct FileInfo { + pub file_id: String, + pub file_name: Option, + pub valid_till: OffsetDateTime, + pub content_type: Mime, + pub delete_on_download: bool, +} + +impl FileInfo { + pub fn new(file_id: String, upload_config: UploadConfig) -> Self { + Self { + file_id, + file_name: upload_config.original_name, + valid_till: upload_config.valid_till, + content_type: upload_config.content_type, + delete_on_download: upload_config.delete_on_download, + } + } +} + +pub async fn find_by_id( + id: &str, + db: &sqlx::Pool, +) -> Result, sqlx::Error> { + sqlx::query_as( + "SELECT file_id, file_name, valid_till, content_type, delete_on_download from files WHERE file_id = $1", + ) + .bind(id) + .fetch_optional(db) + .await +} + +pub async fn create(file_info: &FileInfo, db: &sqlx::Pool) -> Result<(), sqlx::Error> { + sqlx::query( + "INSERT INTO Files (file_id, file_name, content_type, valid_till, delete_on_download) \ + VALUES ($1, $2, $3, $4, $5)", + ) + .bind(&file_info.file_id) + .bind(&file_info.file_name) + .bind(file_info.content_type.to_string()) + .bind(file_info.valid_till) + .bind(file_info.delete_on_download) + .execute(db) + .await + .map(|_| ()) +} + +impl FromRow<'_, PgRow> for FileInfo { + fn from_row(row: &'_ PgRow) -> Result { + Ok(Self { + file_id: row.try_get("file_id")?, + file_name: row.try_get("file_name")?, + valid_till: row.try_get("valid_till")?, + content_type: row + .try_get_raw("content_type")? + .as_str() + .map_err(sqlx::Error::Decode)? + .parse() + .unwrap_or(APPLICATION_OCTET_STREAM), + delete_on_download: row.try_get("delete_on_download")?, + }) + } +} diff --git a/src/main.rs b/src/main.rs index f724487..97eb5b2 100644 --- a/src/main.rs +++ b/src/main.rs @@ -2,6 +2,7 @@ mod config; mod db; mod deleter; mod download; +mod file_info; mod mime_relations; mod multipart; mod rate_limit; @@ -53,8 +54,6 @@ async fn main() -> std::io::Result<()> { let config = config::from_env().await; let (sender, receiver) = channel(8); - log::info!("omnomnom"); - let db = web::Data::new(pool.clone()); let expiry_watch_sender = web::Data::new(sender); let bind_address = env::var("BIND_ADDRESS").unwrap_or_else(|_| "0.0.0.0:8000".to_owned()); @@ -78,6 +77,9 @@ async fn main() -> std::io::Result<()> { .finish() .unwrap(); + log::info!("Listening on {bind_address}"); + log::info!("omnomnom"); + let http_server = HttpServer::new({ move || { App::new() diff --git a/src/upload.rs b/src/upload.rs index 670b472..e9242dc 100644 --- a/src/upload.rs +++ b/src/upload.rs @@ -1,8 +1,8 @@ use std::io::ErrorKind; use crate::config::Config; -use crate::multipart::UploadConfig; -use crate::{multipart, template}; +use crate::file_info::FileInfo; +use crate::{file_info, multipart, template}; use actix_files::NamedFile; use actix_multipart::Multipart; use actix_web::http::header::LOCATION; @@ -33,52 +33,30 @@ pub async fn upload( expiry_watch_sender: web::Data>, config: web::Data, ) -> Result { - let (file_id, file_path) = create_unique_file(&config).await.map_err(|file_err| { + let (file_id, file_path) = create_unique_file_id(&config).await.map_err(|file_err| { log::error!("could not create file {:?}", file_err); error::ErrorInternalServerError("could not create file") })?; - let upload_config = multipart::parse_multipart(payload, &file_path, &config).await?; - let file_name = upload_config.original_name.clone(); + let file_info = FileInfo::new(file_id, upload_config); - insert_file_metadata(&file_id, file_name, &file_path, &upload_config, db).await?; - - log::info!( - "create new file {} (valid_till: {}, content_type: {}, delete_on_download: {})", - file_id, - upload_config.valid_till, - upload_config.content_type, - upload_config.delete_on_download - ); + create_file(&file_info, &file_path, db).await?; expiry_watch_sender.send(()).await.unwrap(); - let redirect = get_redirect_url(&file_id, upload_config.original_name.as_deref()); - let url = template::get_file_url(&req, &file_id, upload_config.original_name.as_deref()); + let redirect = get_redirect_url(&file_info.file_id, file_info.file_name.as_deref()); + let url = template::get_file_url(&req, &file_info.file_id, file_info.file_name.as_deref()); Ok(HttpResponse::SeeOther() .insert_header((LOCATION, redirect)) .body(format!("{url}\n"))) } -async fn insert_file_metadata( - file_id: &String, - file_name: Option, +async fn create_file( + file_info: &FileInfo, file_path: &Path, - upload_config: &UploadConfig, db: web::Data>, ) -> Result<(), Error> { - let db_insert = sqlx::query( - "INSERT INTO Files (file_id, file_name, content_type, valid_till, delete_on_download) \ - VALUES ($1, $2, $3, $4, $5)", - ) - .bind(file_id) - .bind(file_name) - .bind(&upload_config.content_type.to_string()) - .bind(upload_config.valid_till) - .bind(upload_config.delete_on_download) - .execute(db.as_ref()) - .await; - if let Err(db_err) = db_insert { + if let Err(db_err) = file_info::create(file_info, &db).await { log::error!("could not insert into datebase {:?}", db_err); if let Err(file_err) = fs::remove_file(file_path).await { @@ -88,10 +66,17 @@ async fn insert_file_metadata( "could not insert file into database", )); } + log::info!( + "create new file {} (valid_till: {}, content_type: {}, delete_on_download: {})", + file_info.file_id, + file_info.valid_till, + file_info.content_type, + file_info.delete_on_download + ); Ok(()) } -async fn create_unique_file( +async fn create_unique_file_id( config: &web::Data, ) -> Result<(String, PathBuf), std::io::Error> { loop {