258 lines
8.5 KiB
Rust
258 lines
8.5 KiB
Rust
use anyhow::Result;
|
|
use argon2::{
|
|
password_hash::{rand_core::OsRng, PasswordHash, PasswordHasher, PasswordVerifier, SaltString},
|
|
Argon2,
|
|
};
|
|
use axum::{
|
|
extract::{Json, State},
|
|
http::StatusCode,
|
|
response::IntoResponse,
|
|
};
|
|
use chrono::{Duration, Utc};
|
|
use jsonwebtoken::{decode, encode, Algorithm, DecodingKey, EncodingKey, Header, Validation};
|
|
use serde::{Deserialize, Serialize};
|
|
use std::sync::Arc;
|
|
use uuid::Uuid;
|
|
|
|
use crate::AppState;
|
|
|
|
#[derive(Debug, Serialize, Deserialize)]
|
|
pub struct Claims {
|
|
pub sub: Uuid,
|
|
pub exp: i64,
|
|
pub iat: i64,
|
|
}
|
|
|
|
#[derive(Debug, Deserialize)]
|
|
pub struct RegisterRequest {
|
|
pub username: String,
|
|
pub email: String,
|
|
pub password: String,
|
|
}
|
|
|
|
#[derive(Debug, Deserialize)]
|
|
pub struct LoginRequest {
|
|
pub username: String,
|
|
pub password: String,
|
|
}
|
|
|
|
#[derive(Debug, Deserialize)]
|
|
pub struct RefreshRequest {
|
|
pub refresh_token: String,
|
|
}
|
|
|
|
#[derive(Debug, Serialize)]
|
|
pub struct AuthResponse {
|
|
pub access_token: String,
|
|
pub refresh_token: String,
|
|
pub user: UserDto,
|
|
}
|
|
|
|
#[derive(Debug, Serialize)]
|
|
pub struct UserDto {
|
|
pub id: Uuid,
|
|
pub username: String,
|
|
pub email: String,
|
|
pub avatar_seed: String,
|
|
}
|
|
|
|
pub fn hash_password(password: &str) -> Result<String> {
|
|
let salt = SaltString::generate(&mut OsRng);
|
|
let argon2 = Argon2::default();
|
|
argon2.hash_password(password.as_bytes(), &salt)
|
|
.map(|h| h.to_string())
|
|
.map_err(|e| anyhow::anyhow!("hash error: {}", e))
|
|
}
|
|
|
|
pub fn verify_password(password: &str, hash: &str) -> Result<bool> {
|
|
let parsed_hash = PasswordHash::new(hash)
|
|
.map_err(|e| anyhow::anyhow!("parse hash error: {}", e))?;
|
|
Ok(Argon2::default().verify_password(password.as_bytes(), &parsed_hash).is_ok())
|
|
}
|
|
|
|
pub fn create_access_token(user_id: Uuid, secret: &str) -> Result<String> {
|
|
let now = Utc::now();
|
|
let claims = Claims {
|
|
sub: user_id,
|
|
iat: now.timestamp(),
|
|
exp: (now + Duration::minutes(15)).timestamp(),
|
|
};
|
|
Ok(encode(
|
|
&Header::default(),
|
|
&claims,
|
|
&EncodingKey::from_secret(secret.as_bytes()),
|
|
)?)
|
|
}
|
|
|
|
pub fn verify_access_token(token: &str, secret: &str) -> Result<Claims> {
|
|
let data = decode::<Claims>(
|
|
token,
|
|
&DecodingKey::from_secret(secret.as_bytes()),
|
|
&Validation::new(Algorithm::HS256),
|
|
)?;
|
|
Ok(data.claims)
|
|
}
|
|
|
|
pub async fn register(
|
|
State(state): State<Arc<AppState>>,
|
|
Json(req): Json<RegisterRequest>,
|
|
) -> impl IntoResponse {
|
|
if req.username.len() < 3 || req.username.len() > 32 {
|
|
return (StatusCode::BAD_REQUEST, Json(serde_json::json!({"error": "username must be 3-32 chars"}))).into_response();
|
|
}
|
|
if req.password.len() < 8 {
|
|
return (StatusCode::BAD_REQUEST, Json(serde_json::json!({"error": "password must be at least 8 chars"}))).into_response();
|
|
}
|
|
|
|
let password_hash = match hash_password(&req.password) {
|
|
Ok(h) => h,
|
|
Err(_) => return (StatusCode::INTERNAL_SERVER_ERROR, Json(serde_json::json!({"error": "internal error"}))).into_response(),
|
|
};
|
|
|
|
let avatar_seed = Uuid::new_v4().to_string();
|
|
let user_id = Uuid::new_v4();
|
|
|
|
let result = sqlx::query_as::<_, (Uuid, String, String, String)>(
|
|
"INSERT INTO users (id, username, email, password_hash, avatar_seed)
|
|
VALUES ($1, $2, $3, $4, $5)
|
|
RETURNING id, username, email, avatar_seed"
|
|
)
|
|
.bind(user_id)
|
|
.bind(&req.username)
|
|
.bind(&req.email)
|
|
.bind(&password_hash)
|
|
.bind(&avatar_seed)
|
|
.fetch_one(&state.db)
|
|
.await;
|
|
|
|
match result {
|
|
Ok((id, username, email, avatar_seed)) => {
|
|
let access_token = create_access_token(id, &state.jwt_secret).unwrap();
|
|
let refresh_token = issue_refresh_token(id, &state).await.unwrap();
|
|
(StatusCode::CREATED, Json(serde_json::json!(AuthResponse {
|
|
access_token,
|
|
refresh_token,
|
|
user: UserDto { id, username, email, avatar_seed },
|
|
}))).into_response()
|
|
}
|
|
Err(e) => {
|
|
if e.to_string().contains("unique") || e.to_string().contains("duplicate") {
|
|
(StatusCode::CONFLICT, Json(serde_json::json!({"error": "username or email already exists"}))).into_response()
|
|
} else {
|
|
(StatusCode::INTERNAL_SERVER_ERROR, Json(serde_json::json!({"error": "internal error"}))).into_response()
|
|
}
|
|
}
|
|
}
|
|
}
|
|
|
|
pub async fn login(
|
|
State(state): State<Arc<AppState>>,
|
|
Json(req): Json<LoginRequest>,
|
|
) -> impl IntoResponse {
|
|
let row = sqlx::query_as::<_, (Uuid, String, String, String, String)>(
|
|
"SELECT id, username, email, avatar_seed, password_hash FROM users WHERE username = $1"
|
|
)
|
|
.bind(&req.username)
|
|
.fetch_optional(&state.db)
|
|
.await;
|
|
|
|
match row {
|
|
Ok(Some((id, username, email, avatar_seed, password_hash))) => {
|
|
match verify_password(&req.password, &password_hash) {
|
|
Ok(true) => {
|
|
let _ = sqlx::query("UPDATE users SET last_seen = NOW() WHERE id = $1")
|
|
.bind(id)
|
|
.execute(&state.db)
|
|
.await;
|
|
let access_token = create_access_token(id, &state.jwt_secret).unwrap();
|
|
let refresh_token = issue_refresh_token(id, &state).await.unwrap();
|
|
(StatusCode::OK, Json(serde_json::json!(AuthResponse {
|
|
access_token,
|
|
refresh_token,
|
|
user: UserDto { id, username, email, avatar_seed },
|
|
}))).into_response()
|
|
}
|
|
_ => (StatusCode::UNAUTHORIZED, Json(serde_json::json!({"error": "invalid credentials"}))).into_response(),
|
|
}
|
|
}
|
|
_ => (StatusCode::UNAUTHORIZED, Json(serde_json::json!({"error": "invalid credentials"}))).into_response(),
|
|
}
|
|
}
|
|
|
|
pub async fn refresh(
|
|
State(state): State<Arc<AppState>>,
|
|
Json(req): Json<RefreshRequest>,
|
|
) -> impl IntoResponse {
|
|
use sha2::{Digest, Sha256};
|
|
let token_hash = format!("{:x}", Sha256::digest(req.refresh_token.as_bytes()));
|
|
|
|
let row = sqlx::query_as::<_, (Uuid, String, String, String)>(
|
|
r#"SELECT rt.user_id, u.username, u.email, u.avatar_seed
|
|
FROM refresh_tokens rt
|
|
JOIN users u ON u.id = rt.user_id
|
|
WHERE rt.token_hash = $1
|
|
AND rt.revoked_at IS NULL
|
|
AND rt.expires_at > NOW()"#
|
|
)
|
|
.bind(&token_hash)
|
|
.fetch_optional(&state.db)
|
|
.await;
|
|
|
|
match row {
|
|
Ok(Some((user_id, username, email, avatar_seed))) => {
|
|
let _ = sqlx::query("UPDATE refresh_tokens SET revoked_at = NOW() WHERE token_hash = $1")
|
|
.bind(&token_hash)
|
|
.execute(&state.db)
|
|
.await;
|
|
|
|
let access_token = create_access_token(user_id, &state.jwt_secret).unwrap();
|
|
let new_refresh = issue_refresh_token(user_id, &state).await.unwrap();
|
|
(StatusCode::OK, Json(serde_json::json!({
|
|
"access_token": access_token,
|
|
"refresh_token": new_refresh,
|
|
"user": {
|
|
"id": user_id,
|
|
"username": username,
|
|
"email": email,
|
|
"avatar_seed": avatar_seed,
|
|
}
|
|
}))).into_response()
|
|
}
|
|
_ => (StatusCode::UNAUTHORIZED, Json(serde_json::json!({"error": "invalid or expired refresh token"}))).into_response(),
|
|
}
|
|
}
|
|
|
|
pub async fn logout(
|
|
State(state): State<Arc<AppState>>,
|
|
Json(req): Json<RefreshRequest>,
|
|
) -> impl IntoResponse {
|
|
use sha2::{Digest, Sha256};
|
|
let token_hash = format!("{:x}", Sha256::digest(req.refresh_token.as_bytes()));
|
|
let _ = sqlx::query("UPDATE refresh_tokens SET revoked_at = NOW() WHERE token_hash = $1")
|
|
.bind(&token_hash)
|
|
.execute(&state.db)
|
|
.await;
|
|
StatusCode::NO_CONTENT
|
|
}
|
|
|
|
async fn issue_refresh_token(user_id: Uuid, state: &AppState) -> Result<String> {
|
|
use rand::Rng;
|
|
use sha2::{Digest, Sha256};
|
|
let token: String = rand::thread_rng()
|
|
.sample_iter(&rand::distributions::Alphanumeric)
|
|
.take(64)
|
|
.map(char::from)
|
|
.collect();
|
|
let token_hash = format!("{:x}", Sha256::digest(token.as_bytes()));
|
|
let expires_at = Utc::now() + Duration::days(30);
|
|
sqlx::query(
|
|
"INSERT INTO refresh_tokens (user_id, token_hash, expires_at) VALUES ($1, $2, $3)"
|
|
)
|
|
.bind(user_id)
|
|
.bind(&token_hash)
|
|
.bind(expires_at)
|
|
.execute(&state.db)
|
|
.await?;
|
|
Ok(token)
|
|
}
|