use time::Duration; use base64::{engine::general_purpose::STANDARD as BASE64, Engine as _}; use rocket::http::{Cookie, CookieJar, Status}; use rocket::serde::{json::Json, Deserialize, Serialize}; use rocket::State; use rocket_db_pools::Connection; use rocket_dyn_templates::{context, Template}; use std::collections::{HashMap, HashSet}; use std::sync::RwLock; use uuid::Uuid; use crate::Db; pub struct SessionStore(RwLock>>); impl SessionStore { pub fn new() -> Self { SessionStore(RwLock::new(HashMap::new())) } fn generate_secret() -> String { let mut bytes = [0u8; 32]; getrandom::getrandom(&mut bytes).expect("Failed to generate random bytes"); BASE64.encode(bytes) } fn store(&self, user_id: Uuid, secret: String) { let mut store = self.0.write().unwrap(); store .entry(user_id) .or_insert_with(HashSet::new) .insert(secret); } fn verify(&self, user_id: Uuid, secret: &str) -> bool { let store = self.0.read().unwrap(); store .get(&user_id) .map_or(false, |secrets| secrets.contains(secret)) } fn remove(&self, user_id: Uuid, secret: &str) { let mut store = self.0.write().unwrap(); if let Some(secrets) = store.get_mut(&user_id) { secrets.remove(secret); // Clean up the user entry if no sessions remain if secrets.is_empty() { store.remove(&user_id); } } } } #[derive(Debug, Serialize)] #[serde(crate = "rocket::serde")] pub struct User { pub id: Uuid, pub username: String, pub password_hash: String, pub email: Option, pub display_name: Option, pub created_at: chrono::DateTime, pub admin: bool, } impl User { pub fn new( username: String, password_hash: String, email: Option, display_name: Option, ) -> Self { User { id: Uuid::new_v4(), username, password_hash, email, display_name, created_at: chrono::Utc::now(), admin: false, } } pub async fn write_to_database<'a, E>(&self, executor: E) -> sqlx::Result<()> where E: sqlx::Executor<'a, Database = sqlx::Sqlite>, { sqlx::query( r#" INSERT INTO users (id, username, password_hash, email, display_name, created_at, admin) VALUES (?1, ?2, ?3, ?4, ?5, ?6, ?7) "#, ) .bind(self.id.to_string()) .bind(self.username.clone()) .bind(self.password_hash.clone()) .bind(self.email.clone()) .bind(self.display_name.clone()) .bind(self.created_at.to_rfc3339()) .bind(self.admin) .execute(executor) .await?; Ok(()) } } #[derive(Debug, Deserialize)] #[serde(crate = "rocket::serde")] pub struct NewUser { username: String, password: String, email: Option, display_name: Option, } #[derive(Debug, Deserialize)] #[serde(crate = "rocket::serde")] pub struct LoginCredentials { username: String, password: String, } #[derive(Debug, Serialize)] #[serde(crate = "rocket::serde")] pub struct LoginResponse { user_id: Uuid, username: String, } #[derive(Debug, Serialize)] #[serde(crate = "rocket::serde")] pub struct SetupError { error: String, } #[post("/users", data = "")] pub async fn create_user( mut db: Connection, new_user: Json, ) -> Result, Status> { let new_user = new_user.into_inner(); // Hash the password - we'll use bcrypt let password_hash = bcrypt::hash(new_user.password.as_bytes(), bcrypt::DEFAULT_COST) .map_err(|_| Status::InternalServerError)?; let user = User::new( new_user.username, password_hash, new_user.email, new_user.display_name, ); match user.write_to_database(&mut **db).await { Ok(_) => Ok(Json(user)), Err(e) => { eprintln!("Database error: {}", e); match e { sqlx::Error::Database(db_err) if db_err.is_unique_violation() => { Err(Status::Conflict) } _ => Err(Status::InternalServerError), } } } } #[get("/users")] pub async fn get_users(mut db: Connection) -> Result>, Status> { let query = sqlx::query!( r#" SELECT id as "id: String", username, password_hash, email, display_name, created_at as "created_at: chrono::DateTime", admin as "admin: bool" FROM users "# ) .fetch_all(&mut **db) .await .map_err(|e| { eprintln!("Database error: {}", e); Status::InternalServerError })?; // Convert the strings to UUIDs let users = query .into_iter() .map(|row| User { id: Uuid::parse_str(&row.id).unwrap(), username: row.username, password_hash: row.password_hash, email: row.email, display_name: row.display_name, created_at: row.created_at, admin: row.admin, }) .collect::>(); Ok(Json(users)) } #[delete("/users/")] pub async fn delete_user(mut db: Connection, user_id: &str) -> Status { // Validate UUID format let uuid = match Uuid::parse_str(user_id) { Ok(uuid) => uuid, Err(_) => return Status::BadRequest, }; let query = sqlx::query("DELETE FROM users WHERE id = ?") .bind(uuid.to_string()) .execute(&mut **db) .await; match query { Ok(result) => { if result.rows_affected() > 0 { Status::NoContent } else { Status::NotFound } } Err(e) => { eprintln!("Database error: {}", e); Status::InternalServerError } } } #[post("/login", data = "")] pub async fn login( mut db: Connection, credentials: Json, cookies: &CookieJar<'_>, sessions: &State, ) -> Result, Status> { let creds = credentials.into_inner(); // Find user by username let user = sqlx::query!( r#" SELECT id as "id: String", username, password_hash FROM users WHERE username = ? "#, creds.username ) .fetch_optional(&mut **db) .await .map_err(|e| { eprintln!("Database error: {}", e); Status::InternalServerError })?; let user = match user { Some(user) => user, None => return Err(Status::Unauthorized), }; // Verify password let valid = bcrypt::verify(creds.password.as_bytes(), &user.password_hash) .map_err(|_| Status::InternalServerError)?; if !valid { return Err(Status::Unauthorized); } // Generate and store session let user_id = Uuid::parse_str(&user.id).map_err(|_| Status::InternalServerError)?; let session_secret = SessionStore::generate_secret(); sessions.store(user_id, session_secret.clone()); // Set session cookie with both user_id and secret let max_age = Duration::days(6); let cookie_value = format!("{}:{}", user_id, session_secret); let mut cookie = Cookie::new("user_id", cookie_value); cookie.set_max_age(max_age); cookies.add_private(cookie); Ok(Json(LoginResponse { user_id, username: user.username, })) } #[post("/logout")] pub fn logout(cookies: &CookieJar<'_>, sessions: &State) -> Status { if let Some(cookie) = cookies.get_private("user_id") { if let Some((user_id, secret)) = cookie.value().split_once(':') { if let Ok(user_id) = Uuid::parse_str(user_id) { sessions.remove(user_id, secret); } } cookies.remove_private(Cookie::build("user_id")); } Status::NoContent } // Add auth guard pub struct AuthenticatedUser { pub user_id: Uuid, } #[rocket::async_trait] impl<'r> rocket::request::FromRequest<'r> for AuthenticatedUser { type Error = (); async fn from_request( request: &'r rocket::Request<'_>, ) -> rocket::request::Outcome { use rocket::request::Outcome; let sessions = request.rocket().state::().unwrap(); match request.cookies().get_private("user_id") { Some(cookie) => { if let Some((user_id, secret)) = cookie.value().split_once(':') { if let Ok(user_id) = Uuid::parse_str(user_id) { if sessions.verify(user_id, secret) { return Outcome::Success(AuthenticatedUser { user_id }); } } } Outcome::Forward(Status::Unauthorized) } None => Outcome::Forward(Status::Unauthorized), } } } #[get("/setup")] pub async fn setup_page(mut db: Connection) -> Result { // Check if any users exist let count = sqlx::query!("SELECT COUNT(*) as count FROM users") .fetch_one(&mut **db) .await .map_err(|_| Status::InternalServerError)? .count; if count > 0 { // If users exist, redirect to login Err(Status::SeeOther) } else { // Show setup page Ok(Template::render("setup", context! {})) } } #[post("/setup", data = "")] pub async fn setup( mut db: Connection, new_user: Json, ) -> Result> { let new_user = new_user.into_inner(); // Check if any users exist let count = sqlx::query!("SELECT COUNT(*) as count FROM users") .fetch_one(&mut **db) .await .map_err(|e| { eprintln!("Database error: {}", e); Json(SetupError { error: "Internal server error".to_string(), }) })? .count; if count > 0 { return Err(Json(SetupError { error: "Setup has already been completed".to_string(), })); } let password = new_user.password.as_bytes(); // Hash the password let password_hash = bcrypt::hash(password, bcrypt::DEFAULT_COST).map_err(|e| { eprintln!("Password hashing error: {}", e); Json(SetupError { error: "Failed to process password".to_string(), }) })?; let mut user = User::new( new_user.username, password_hash, new_user.email, new_user.display_name, ); user.admin = true; // This is an admin user match user.write_to_database(&mut **db).await { Ok(_) => Ok(Status::Created), Err(e) => { eprintln!("Database error: {}", e); match e { sqlx::Error::Database(db_err) if db_err.is_unique_violation() => { Err(Json(SetupError { error: "Username already exists".to_string(), })) } _ => Err(Json(SetupError { error: "Failed to create user".to_string(), })), } } } }