From c74ebbe5fa73da24e1aba094f6cd0ef72a25fa56 Mon Sep 17 00:00:00 2001 From: Greg Shuflin Date: Tue, 4 Feb 2025 00:45:49 -0800 Subject: [PATCH] multi sessions --- Cargo.lock | 2 ++ Cargo.toml | 2 ++ src/demo.rs | 13 +++++---- src/main.rs | 6 +++- src/user.rs | 79 ++++++++++++++++++++++++++++++++++++++++++++++------- 5 files changed, 85 insertions(+), 17 deletions(-) diff --git a/Cargo.lock b/Cargo.lock index a4b1b06..2f796ad 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -2663,10 +2663,12 @@ version = "0.1.0" dependencies = [ "argon2", "atom_syndication", + "base64 0.21.7", "bcrypt", "chrono", "clap", "feed-rs", + "getrandom 0.2.15", "reqwest", "rocket", "rocket_db_pools", diff --git a/Cargo.toml b/Cargo.toml index 2606023..c7b92c5 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -20,3 +20,5 @@ feed-rs = "2.3.1" reqwest = { version = "0.12.12", features = ["json"] } tokio = "1.43.0" time = "0.3.37" +base64 = "0.21" +getrandom = "0.2" diff --git a/src/demo.rs b/src/demo.rs index dc2cae1..cbdcff3 100644 --- a/src/demo.rs +++ b/src/demo.rs @@ -1,6 +1,6 @@ use chrono; -use sqlx; use rocket::serde; +use sqlx; use uuid::Uuid; use crate::feeds::Feed; @@ -77,16 +77,17 @@ pub async fn setup_demo_data(pool: &sqlx::SqlitePool) { ); acx.categorization = vec!["Substack".to_string()]; - let feeds = [bbc_news, xkcd, isidore, acx]; for feed in feeds { // TODO: This insert logic is substantially the same as Feed::write_to_database. // Should find a way to unify these two code paths to avoid duplication. - let categorization_json = serde::json::to_value(feed.categorization).map_err(|e| { - eprintln!("Failed to serialize categorization: {}", e); - sqlx::Error::Decode(Box::new(e)) - }).unwrap(); + let categorization_json = serde::json::to_value(feed.categorization) + .map_err(|e| { + eprintln!("Failed to serialize categorization: {}", e); + sqlx::Error::Decode(Box::new(e)) + }) + .unwrap(); println!("{}", categorization_json); sqlx::query( diff --git a/src/main.rs b/src/main.rs index 28a90ff..265b453 100644 --- a/src/main.rs +++ b/src/main.rs @@ -101,7 +101,10 @@ fn rocket() -> _ { let figment = rocket::Config::figment() .merge(("databases.rss_data.url", db_url)) - .merge(("secret_key", std::env::var("SECRET_KEY").expect("SECRET_KEY environment variable must be set"))); + .merge(( + "secret_key", + std::env::var("SECRET_KEY").expect("SECRET_KEY environment variable must be set"), + )); rocket::custom(figment) .mount( @@ -128,6 +131,7 @@ fn rocket() -> _ { .attach(Template::fairing()) .attach(Db::init()) .manage(args.demo) + .manage(user::SessionStore::new()) .attach(AdHoc::try_on_ignite("DB Setup", move |rocket| async move { setup_database(args.demo, rocket).await })) diff --git a/src/user.rs b/src/user.rs index a988c9a..70a6dad 100644 --- a/src/user.rs +++ b/src/user.rs @@ -1,13 +1,57 @@ use time::Duration; +use base64::{engine::general_purpose::STANDARD as BASE64, Engine as _}; use rocket::http::{Cookie, CookieJar, Status}; use rocket::serde::{json::Json, Deserialize, Serialize}; +use rocket::State; use rocket_db_pools::Connection; use rocket_dyn_templates::{context, Template}; +use std::collections::{HashMap, HashSet}; +use std::sync::RwLock; use uuid::Uuid; use crate::Db; +pub struct SessionStore(RwLock>>); + +impl SessionStore { + pub fn new() -> Self { + SessionStore(RwLock::new(HashMap::new())) + } + + fn generate_secret() -> String { + let mut bytes = [0u8; 32]; + getrandom::getrandom(&mut bytes).expect("Failed to generate random bytes"); + BASE64.encode(bytes) + } + + fn store(&self, user_id: Uuid, secret: String) { + let mut store = self.0.write().unwrap(); + store + .entry(user_id) + .or_insert_with(HashSet::new) + .insert(secret); + } + + fn verify(&self, user_id: Uuid, secret: &str) -> bool { + let store = self.0.read().unwrap(); + store + .get(&user_id) + .map_or(false, |secrets| secrets.contains(secret)) + } + + fn remove(&self, user_id: Uuid, secret: &str) { + let mut store = self.0.write().unwrap(); + if let Some(secrets) = store.get_mut(&user_id) { + secrets.remove(secret); + // Clean up the user entry if no sessions remain + if secrets.is_empty() { + store.remove(&user_id); + } + } + } +} + #[derive(Debug, Serialize)] #[serde(crate = "rocket::serde")] pub struct User { @@ -184,6 +228,7 @@ pub async fn login( mut db: Connection, credentials: Json, cookies: &CookieJar<'_>, + sessions: &State, ) -> Result, Status> { let creds = credentials.into_inner(); @@ -219,13 +264,15 @@ pub async fn login( return Err(Status::Unauthorized); } - // Set session cookie - // TODO tehre should be a more complicated notion of a session + // Generate and store session let user_id = Uuid::parse_str(&user.id).map_err(|_| Status::InternalServerError)?; + let session_secret = SessionStore::generate_secret(); + sessions.store(user_id, session_secret.clone()); - //TODO make this user-configurable + // Set session cookie with both user_id and secret let max_age = Duration::days(6); - let mut cookie = Cookie::new("user_id", user_id.to_string()); + let cookie_value = format!("{}:{}", user_id, session_secret); + let mut cookie = Cookie::new("user_id", cookie_value); cookie.set_max_age(max_age); cookies.add_private(cookie); @@ -236,8 +283,15 @@ pub async fn login( } #[post("/logout")] -pub fn logout(cookies: &CookieJar<'_>) -> Status { - cookies.remove_private(Cookie::build("user_id")); +pub fn logout(cookies: &CookieJar<'_>, sessions: &State) -> Status { + if let Some(cookie) = cookies.get_private("user_id") { + if let Some((user_id, secret)) = cookie.value().split_once(':') { + if let Ok(user_id) = Uuid::parse_str(user_id) { + sessions.remove(user_id, secret); + } + } + cookies.remove_private(Cookie::build("user_id")); + } Status::NoContent } @@ -255,13 +309,18 @@ impl<'r> rocket::request::FromRequest<'r> for AuthenticatedUser { ) -> rocket::request::Outcome { use rocket::request::Outcome; + let sessions = request.rocket().state::().unwrap(); + match request.cookies().get_private("user_id") { Some(cookie) => { - if let Ok(user_id) = Uuid::parse_str(cookie.value()) { - Outcome::Success(AuthenticatedUser { user_id }) - } else { - Outcome::Forward(Status::Unauthorized) + if let Some((user_id, secret)) = cookie.value().split_once(':') { + if let Ok(user_id) = Uuid::parse_str(user_id) { + if sessions.verify(user_id, secret) { + return Outcome::Success(AuthenticatedUser { user_id }); + } + } } + Outcome::Forward(Status::Unauthorized) } None => Outcome::Forward(Status::Unauthorized), }