diff --git a/src/csrf.rs b/src/csrf.rs new file mode 100644 index 0000000..34d2d43 --- /dev/null +++ b/src/csrf.rs @@ -0,0 +1,27 @@ +use rocket::{Data, Request}; +use rocket::fairing::{Fairing as RocketFairing, Info, Kind}; + +const COOKIE_NAME: &str = "csrf_token"; + +pub struct Fairing; + +impl Fairing { + pub fn new() -> Self { + Self {} + } +} + +impl RocketFairing for Fairing { + fn info(&self) -> Info { + Info { + name: "CSRF (Cross-Site Request Forgery) protection", + kind: Kind::Request, + } + } + + fn on_request(&self, request: &mut Request, _: &Data) { + let _token: Option = request.cookies() + .get_private(COOKIE_NAME) + .and_then(|cookie| Some(cookie.value().to_string())); + } +} diff --git a/src/main.rs b/src/main.rs index d60d091..709ae9e 100644 --- a/src/main.rs +++ b/src/main.rs @@ -2,6 +2,7 @@ #[cfg(test)] mod tests; +mod csrf; mod config; mod web; mod database; diff --git a/src/web.rs b/src/web.rs index 877c2b9..3423850 100644 --- a/src/web.rs +++ b/src/web.rs @@ -1,3 +1,4 @@ +use crate::csrf; use crate::config; use crate::database; use crate::routes; @@ -10,10 +11,9 @@ pub fn rocket(config: &config::Config) -> Result { let public_path = config.public_path()?; - let secret_key = config.secret_key.as_ref().unwrap().to_string(); - let result = rocket::custom(rocket_config) .manage(database::create_db_pool(config)) + .attach(csrf::Fairing::new()) .attach(Template::fairing()) .mount("/", routes::routes()) .mount("/", StaticFiles::new(public_path, ServeOptions::None));