Merge 005df63be7
into 58cebbac0b
This commit is contained in:
commit
a885052a5b
|
@ -0,0 +1,106 @@
|
|||
use rocket::{
|
||||
async_trait,
|
||||
data::{Data, FromData, Outcome},
|
||||
form::{Error, Form, FromForm},
|
||||
http::Status,
|
||||
Request,
|
||||
};
|
||||
|
||||
use crate::CsrfToken;
|
||||
|
||||
pub struct CsrfForm<T>(T);
|
||||
|
||||
impl<T> std::ops::Deref for CsrfForm<T> {
|
||||
type Target = T;
|
||||
|
||||
fn deref(&self) -> &Self::Target {
|
||||
&self.0
|
||||
}
|
||||
}
|
||||
|
||||
#[derive(Debug)]
|
||||
pub enum CsrfError<E> {
|
||||
CSRFTokenInvalid,
|
||||
Other(E),
|
||||
}
|
||||
|
||||
struct CsrfTokenForm<'r, T> {
|
||||
token: &'r str,
|
||||
inner: T,
|
||||
}
|
||||
|
||||
impl<'r, T: FromForm<'r>> FromForm<'r> for CsrfTokenForm<'r, T> {
|
||||
type Context = (Option<&'r str>, T::Context);
|
||||
fn init(opts: rocket::form::Options) -> Self::Context {
|
||||
(None, T::init(opts))
|
||||
}
|
||||
|
||||
fn push_value(ctxt: &mut Self::Context, field: rocket::form::ValueField<'r>) {
|
||||
if field.name == "csrf_token" {
|
||||
ctxt.0 = Some(field.value);
|
||||
} else {
|
||||
T::push_value(&mut ctxt.1, field);
|
||||
}
|
||||
}
|
||||
|
||||
fn push_data<'life0, 'life1, 'async_trait>(
|
||||
ctxt: &'life0 mut Self::Context,
|
||||
field: rocket::form::DataField<'r, 'life1>,
|
||||
) -> core::pin::Pin<
|
||||
Box<dyn core::future::Future<Output = ()> + core::marker::Send + 'async_trait>,
|
||||
>
|
||||
where
|
||||
'r: 'async_trait,
|
||||
'life0: 'async_trait,
|
||||
'life1: 'async_trait,
|
||||
Self: 'async_trait,
|
||||
{
|
||||
T::push_data(&mut ctxt.1, field)
|
||||
}
|
||||
|
||||
fn finalize(ctxt: Self::Context) -> rocket::form::Result<'r, Self> {
|
||||
let inner = T::finalize(ctxt.1)?;
|
||||
if let Some(token) = ctxt.0 {
|
||||
Ok(Self { token, inner })
|
||||
} else {
|
||||
Err(Error::validation("csrf_token is required").into())
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
#[async_trait]
|
||||
impl<'r, T: FromForm<'r>> FromData<'r> for CsrfForm<T> {
|
||||
type Error = CsrfError<<Form<T> as FromData<'r>>::Error>;
|
||||
async fn from_data(r: &'r Request<'_>, d: Data<'r>) -> Outcome<'r, Self> {
|
||||
use rocket::outcome::Outcome::*;
|
||||
let token: CsrfToken = match r.guard().await {
|
||||
Success(t) => t,
|
||||
Failure((s, _e)) => return Outcome::Failure((s, CsrfError::CSRFTokenInvalid)),
|
||||
Forward(()) => return Outcome::Forward(d),
|
||||
};
|
||||
// Bypass token in form fields if header is set
|
||||
if let Some(header) = r.headers().get_one("X-CSRF-Token") {
|
||||
if token.verify(header).is_ok() {
|
||||
let form: Form<T> = match Form::from_data(r, d).await {
|
||||
Success(t) => t,
|
||||
Failure((s, e)) => return Outcome::Failure((s, CsrfError::Other(e))),
|
||||
Forward(d) => return Outcome::Forward(d),
|
||||
};
|
||||
Outcome::Success(Self(form.into_inner()))
|
||||
} else {
|
||||
Outcome::Failure((Status::NotAcceptable, CsrfError::CSRFTokenInvalid))
|
||||
}
|
||||
} else {
|
||||
let form: Form<CsrfTokenForm<T>> = match Form::from_data(r, d).await {
|
||||
Success(t) => t,
|
||||
Failure((s, e)) => return Outcome::Failure((s, CsrfError::Other(e))),
|
||||
Forward(d) => return Outcome::Forward(d),
|
||||
};
|
||||
if token.verify(form.token).is_ok() {
|
||||
Outcome::Success(Self(form.into_inner().inner))
|
||||
} else {
|
||||
Outcome::Failure((Status::NotAcceptable, CsrfError::CSRFTokenInvalid))
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
17
src/lib.rs
17
src/lib.rs
|
@ -2,14 +2,19 @@ use bcrypt::{hash, verify};
|
|||
use rand::{distributions::Standard, Rng};
|
||||
use rocket::{
|
||||
async_trait,
|
||||
data::FromData,
|
||||
fairing::{self, Fairing as RocketFairing, Info, Kind},
|
||||
form::{Form, FromForm},
|
||||
http::{Cookie, Status},
|
||||
request::{FromRequest, Outcome},
|
||||
sentinel::resolution::DefaultSentinel,
|
||||
time::{Duration, OffsetDateTime},
|
||||
Data, Request, Rocket, State,
|
||||
Data, Request, Rocket, Sentinel, State,
|
||||
};
|
||||
use std::borrow::Cow;
|
||||
|
||||
pub mod form;
|
||||
|
||||
const BCRYPT_COST: u32 = 8;
|
||||
|
||||
const _PARAM_NAME: &str = "authenticity_token";
|
||||
|
@ -86,7 +91,7 @@ impl CsrfToken {
|
|||
hash(&self.0, BCRYPT_COST).unwrap()
|
||||
}
|
||||
|
||||
pub fn verify(&self, form_authenticity_token: &String) -> Result<(), VerificationFailure> {
|
||||
pub fn verify(&self, form_authenticity_token: &str) -> Result<(), VerificationFailure> {
|
||||
if verify(&self.0, form_authenticity_token).unwrap_or(false) {
|
||||
Ok(())
|
||||
} else {
|
||||
|
@ -146,6 +151,14 @@ impl<'r> FromRequest<'r> for CsrfToken {
|
|||
}
|
||||
}
|
||||
|
||||
// Implement Sentinel for CsrfToken to require CsrfConfig to be attached
|
||||
impl Sentinel for CsrfToken {
|
||||
fn abort(rocket: &Rocket<rocket::Ignite>) -> bool {
|
||||
// Delegate to `&State<CsrfConfig>`
|
||||
State::<CsrfConfig>::abort(rocket)
|
||||
}
|
||||
}
|
||||
|
||||
trait RequestCsrf {
|
||||
fn valid_csrf_token_from_session(&self, config: &CsrfConfig) -> Option<Vec<u8>> {
|
||||
self.csrf_token_from_session(config).and_then(|raw| {
|
||||
|
|
Loading…
Reference in New Issue