add cache headers for file downloads

This commit is contained in:
neri 2022-06-30 01:04:03 +02:00
parent 24fc93cef7
commit 36b9096325
1 changed files with 49 additions and 12 deletions

View File

@ -1,17 +1,19 @@
use std::{path::PathBuf, str::FromStr}; use std::{path::PathBuf, str::FromStr, time::SystemTime};
use actix_files::NamedFile; use actix_files::NamedFile;
use actix_web::{ use actix_web::{
error, error,
http::header::{ http::header::{
Accept, Charset, ContentDisposition, DispositionParam, DispositionType, ExtendedValue, Accept, CacheControl, CacheDirective, Charset, ContentDisposition, DispositionParam,
Header, DispositionType, Expires, ExtendedValue, Header, HeaderValue, HttpDate, TryIntoHeaderValue,
ACCEPT, CACHE_CONTROL, EXPIRES, VARY,
}, },
web, Error, HttpRequest, HttpResponse, web, Error, HttpRequest, HttpResponse,
}; };
use mime::{Mime, TEXT_HTML}; use mime::{Mime, TEXT_HTML};
use sqlx::postgres::PgPool; use sqlx::postgres::PgPool;
use std::path::Path; use std::path::Path;
use time::OffsetDateTime;
use tokio::fs; use tokio::fs;
use url::Url; use url::Url;
@ -35,16 +37,19 @@ pub async fn download(
config: web::Data<Config>, config: web::Data<Config>,
) -> Result<HttpResponse, Error> { ) -> Result<HttpResponse, Error> {
let id = req.match_info().query("id"); let id = req.match_info().query("id");
let (file_id, file_name, file_kind, delete) = load_file_info(id, &db).await?; let (file_id, file_name, valid_till, file_kind, delete) = load_file_info(id, &db).await?;
let mut path = config.files_dir.clone(); let mut path = config.files_dir.clone();
path.push(&file_id); path.push(&file_id);
let file_mime = get_content_type(&path); let file_mime = get_content_type(&path);
let response = match get_view_type(&req, &file_kind, &file_mime, &path, delete).await { let mut response = match get_view_type(&req, &file_kind, &file_mime, &path, delete).await {
ViewType::Raw => build_file_response(false, &file_name, path, file_mime, &req), ViewType::Raw => build_file_response(false, &file_name, path, file_mime, &req).await,
ViewType::Download => build_file_response(true, &file_name, path, file_mime, &req), ViewType::Download => build_file_response(true, &file_name, path, file_mime, &req).await,
ViewType::Html => build_text_response(&path).await, ViewType::Html => build_text_response(&path).await,
}; }?;
insert_cache_headers(&mut response, valid_till);
if delete { if delete {
deleter::delete_by_id(&db, &file_id, &config.files_dir) deleter::delete_by_id(&db, &file_id, &config.files_dir)
.await .await
@ -53,15 +58,16 @@ pub async fn download(
error::ErrorInternalServerError("could not delete file") error::ErrorInternalServerError("could not delete file")
})?; })?;
} }
response
Ok(response)
} }
async fn load_file_info( async fn load_file_info(
id: &str, id: &str,
db: &web::Data<sqlx::Pool<sqlx::Postgres>>, db: &web::Data<sqlx::Pool<sqlx::Postgres>>,
) -> Result<(String, String, String, bool), Error> { ) -> Result<(String, String, OffsetDateTime, String, bool), Error> {
sqlx::query_as( sqlx::query_as(
"SELECT file_id, file_name, kind, delete_on_download from files WHERE file_id = $1", "SELECT file_id, file_name, valid_till, kind, delete_on_download from files WHERE file_id = $1",
) )
.bind(id) .bind(id)
.fetch_optional(db.as_ref()) .fetch_optional(db.as_ref())
@ -147,7 +153,7 @@ async fn build_text_response(path: &Path) -> Result<HttpResponse, Error> {
.body(html)) .body(html))
} }
fn build_file_response( async fn build_file_response(
download: bool, download: bool,
file_name: &str, file_name: &str,
path: PathBuf, path: PathBuf,
@ -169,6 +175,7 @@ fn build_file_response(
})? })?
.set_content_type(content_type) .set_content_type(content_type)
.set_content_disposition(content_disposition); .set_content_disposition(content_disposition);
Ok(file.into_response(req)) Ok(file.into_response(req))
} }
@ -183,3 +190,33 @@ fn get_disposition_params(filename: &str) -> Vec<DispositionParam> {
} }
parameters parameters
} }
fn insert_cache_headers(response: &mut HttpResponse, valid_till: OffsetDateTime) {
if response.status().is_success() {
let valid_duration = valid_till - OffsetDateTime::now_utc();
let valid_cache_seconds = valid_duration.whole_seconds().clamp(0, u32::MAX as i64) as u32;
response.headers_mut().insert(
CACHE_CONTROL,
CacheControl(vec![
CacheDirective::Public,
CacheDirective::MustRevalidate,
CacheDirective::MaxAge(valid_cache_seconds), // todo: expiry in seconds
CacheDirective::NoTransform,
CacheDirective::Extension("immutable".to_owned(), None),
])
.try_into_value()
.unwrap(),
);
response.headers_mut().insert(
EXPIRES,
Expires(HttpDate::from(
SystemTime::now() + std::time::Duration::from_secs(valid_cache_seconds.into()),
))
.try_into_value()
.unwrap(),
);
}
response
.headers_mut()
.insert(VARY, HeaderValue::from_name(ACCEPT));
}