Files
FunConnect/server/src/auth_middleware.rs

60 lines
1.7 KiB
Rust
Raw Normal View History

use axum::{
extract::{FromRef, FromRequestParts},
http::{header, StatusCode},
};
use std::sync::Arc;
use uuid::Uuid;
use crate::AppState;
use crate::api::auth::verify_access_token;
/// Identity of a request whose bearer access token passed verification.
///
/// Obtained through the [`FromRequestParts`] implementation in this module, so a
/// handler that takes an `AuthUser` argument only runs for authenticated requests.
#[derive(Clone, Debug)]
pub struct AuthUser {
    /// The `sub` claim of the verified access token.
    pub user_id: Uuid,
}
impl<S> FromRequestParts<S> for AuthUser
where
S: Send + Sync,
Arc<AppState>: FromRef<S>,
{
type Rejection = (StatusCode, axum::Json<serde_json::Value>);
fn from_request_parts<'life0, 'life1, 'async_trait>(
parts: &'life0 mut axum::http::request::Parts,
state: &'life1 S,
) -> ::core::pin::Pin<Box<
dyn ::core::future::Future<Output = Result<Self, Self::Rejection>>
+ ::core::marker::Send
+ 'async_trait,
>>
where
'life0: 'async_trait,
'life1: 'async_trait,
Self: 'async_trait,
{
Box::pin(async move {
let state = Arc::<AppState>::from_ref(state);
let auth_header = parts
.headers
.get(header::AUTHORIZATION)
.and_then(|v| v.to_str().ok())
.and_then(|v| v.strip_prefix("Bearer "));
match auth_header {
Some(token) => match verify_access_token(token, &state.jwt_secret) {
Ok(claims) => Ok(AuthUser { user_id: claims.sub }),
Err(_) => Err((
StatusCode::UNAUTHORIZED,
axum::Json(serde_json::json!({"error": "invalid or expired token"})),
)),
},
None => Err((
StatusCode::UNAUTHORIZED,
axum::Json(serde_json::json!({"error": "missing authorization header"})),
)),
}
})
}
}