This commit is contained in:
Matthew Pomes 2024-02-08 01:58:45 +04:00 committed by GitHub
commit a885052a5b
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
2 changed files with 121 additions and 2 deletions

106
src/form.rs Normal file
View File

@ -0,0 +1,106 @@
use rocket::{
async_trait,
data::{Data, FromData, Outcome},
form::{Error, Form, FromForm},
http::Status,
Request,
};
use crate::CsrfToken;
pub struct CsrfForm<T>(T);
impl<T> std::ops::Deref for CsrfForm<T> {
type Target = T;
fn deref(&self) -> &Self::Target {
&self.0
}
}
#[derive(Debug)]
pub enum CsrfError<E> {
CSRFTokenInvalid,
Other(E),
}
struct CsrfTokenForm<'r, T> {
token: &'r str,
inner: T,
}
impl<'r, T: FromForm<'r>> FromForm<'r> for CsrfTokenForm<'r, T> {
type Context = (Option<&'r str>, T::Context);
fn init(opts: rocket::form::Options) -> Self::Context {
(None, T::init(opts))
}
fn push_value(ctxt: &mut Self::Context, field: rocket::form::ValueField<'r>) {
if field.name == "csrf_token" {
ctxt.0 = Some(field.value);
} else {
T::push_value(&mut ctxt.1, field);
}
}
fn push_data<'life0, 'life1, 'async_trait>(
ctxt: &'life0 mut Self::Context,
field: rocket::form::DataField<'r, 'life1>,
) -> core::pin::Pin<
Box<dyn core::future::Future<Output = ()> + core::marker::Send + 'async_trait>,
>
where
'r: 'async_trait,
'life0: 'async_trait,
'life1: 'async_trait,
Self: 'async_trait,
{
T::push_data(&mut ctxt.1, field)
}
fn finalize(ctxt: Self::Context) -> rocket::form::Result<'r, Self> {
let inner = T::finalize(ctxt.1)?;
if let Some(token) = ctxt.0 {
Ok(Self { token, inner })
} else {
Err(Error::validation("csrf_token is required").into())
}
}
}
#[async_trait]
impl<'r, T: FromForm<'r>> FromData<'r> for CsrfForm<T> {
type Error = CsrfError<<Form<T> as FromData<'r>>::Error>;
async fn from_data(r: &'r Request<'_>, d: Data<'r>) -> Outcome<'r, Self> {
use rocket::outcome::Outcome::*;
let token: CsrfToken = match r.guard().await {
Success(t) => t,
Failure((s, _e)) => return Outcome::Failure((s, CsrfError::CSRFTokenInvalid)),
Forward(()) => return Outcome::Forward(d),
};
// Bypass token in form fields if header is set
if let Some(header) = r.headers().get_one("X-CSRF-Token") {
if token.verify(header).is_ok() {
let form: Form<T> = match Form::from_data(r, d).await {
Success(t) => t,
Failure((s, e)) => return Outcome::Failure((s, CsrfError::Other(e))),
Forward(d) => return Outcome::Forward(d),
};
Outcome::Success(Self(form.into_inner()))
} else {
Outcome::Failure((Status::NotAcceptable, CsrfError::CSRFTokenInvalid))
}
} else {
let form: Form<CsrfTokenForm<T>> = match Form::from_data(r, d).await {
Success(t) => t,
Failure((s, e)) => return Outcome::Failure((s, CsrfError::Other(e))),
Forward(d) => return Outcome::Forward(d),
};
if token.verify(form.token).is_ok() {
Outcome::Success(Self(form.into_inner().inner))
} else {
Outcome::Failure((Status::NotAcceptable, CsrfError::CSRFTokenInvalid))
}
}
}
}

View File

@ -2,14 +2,19 @@ use bcrypt::{hash, verify};
use rand::{distributions::Standard, Rng};
use rocket::{
async_trait,
data::FromData,
fairing::{self, Fairing as RocketFairing, Info, Kind},
form::{Form, FromForm},
http::{Cookie, Status},
request::{FromRequest, Outcome},
sentinel::resolution::DefaultSentinel,
time::{Duration, OffsetDateTime},
Data, Request, Rocket, State,
Data, Request, Rocket, Sentinel, State,
};
use std::borrow::Cow;
pub mod form;
const BCRYPT_COST: u32 = 8;
const _PARAM_NAME: &str = "authenticity_token";
@ -86,7 +91,7 @@ impl CsrfToken {
hash(&self.0, BCRYPT_COST).unwrap()
}
pub fn verify(&self, form_authenticity_token: &String) -> Result<(), VerificationFailure> {
pub fn verify(&self, form_authenticity_token: &str) -> Result<(), VerificationFailure> {
if verify(&self.0, form_authenticity_token).unwrap_or(false) {
Ok(())
} else {
@ -146,6 +151,14 @@ impl<'r> FromRequest<'r> for CsrfToken {
}
}
// Implement Sentinel for CsrfToken to require CsrfConfig to be attached
impl Sentinel for CsrfToken {
fn abort(rocket: &Rocket<rocket::Ignite>) -> bool {
// Delegate to `&State<CsrfConfig>`
State::<CsrfConfig>::abort(rocket)
}
}
trait RequestCsrf {
fn valid_csrf_token_from_session(&self, config: &CsrfConfig) -> Option<Vec<u8>> {
self.csrf_token_from_session(config).and_then(|raw| {