multi sessions

This commit is contained in:
Greg Shuflin 2025-02-04 00:45:49 -08:00
parent 5f3a9473ff
commit c74ebbe5fa
5 changed files with 85 additions and 17 deletions

2
Cargo.lock generated
View File

@ -2663,10 +2663,12 @@ version = "0.1.0"
dependencies = [ dependencies = [
"argon2", "argon2",
"atom_syndication", "atom_syndication",
"base64 0.21.7",
"bcrypt", "bcrypt",
"chrono", "chrono",
"clap", "clap",
"feed-rs", "feed-rs",
"getrandom 0.2.15",
"reqwest", "reqwest",
"rocket", "rocket",
"rocket_db_pools", "rocket_db_pools",

View File

@ -20,3 +20,5 @@ feed-rs = "2.3.1"
reqwest = { version = "0.12.12", features = ["json"] } reqwest = { version = "0.12.12", features = ["json"] }
tokio = "1.43.0" tokio = "1.43.0"
time = "0.3.37" time = "0.3.37"
base64 = "0.21"
getrandom = "0.2"

View File

@ -1,6 +1,6 @@
use chrono; use chrono;
use sqlx;
use rocket::serde; use rocket::serde;
use sqlx;
use uuid::Uuid; use uuid::Uuid;
use crate::feeds::Feed; use crate::feeds::Feed;
@ -77,16 +77,17 @@ pub async fn setup_demo_data(pool: &sqlx::SqlitePool) {
); );
acx.categorization = vec!["Substack".to_string()]; acx.categorization = vec!["Substack".to_string()];
let feeds = [bbc_news, xkcd, isidore, acx]; let feeds = [bbc_news, xkcd, isidore, acx];
for feed in feeds { for feed in feeds {
// TODO: This insert logic is substantially the same as Feed::write_to_database. // TODO: This insert logic is substantially the same as Feed::write_to_database.
// Should find a way to unify these two code paths to avoid duplication. // Should find a way to unify these two code paths to avoid duplication.
let categorization_json = serde::json::to_value(feed.categorization).map_err(|e| { let categorization_json = serde::json::to_value(feed.categorization)
eprintln!("Failed to serialize categorization: {}", e); .map_err(|e| {
sqlx::Error::Decode(Box::new(e)) eprintln!("Failed to serialize categorization: {}", e);
}).unwrap(); sqlx::Error::Decode(Box::new(e))
})
.unwrap();
println!("{}", categorization_json); println!("{}", categorization_json);
sqlx::query( sqlx::query(

View File

@ -101,7 +101,10 @@ fn rocket() -> _ {
let figment = rocket::Config::figment() let figment = rocket::Config::figment()
.merge(("databases.rss_data.url", db_url)) .merge(("databases.rss_data.url", db_url))
.merge(("secret_key", std::env::var("SECRET_KEY").expect("SECRET_KEY environment variable must be set"))); .merge((
"secret_key",
std::env::var("SECRET_KEY").expect("SECRET_KEY environment variable must be set"),
));
rocket::custom(figment) rocket::custom(figment)
.mount( .mount(
@ -128,6 +131,7 @@ fn rocket() -> _ {
.attach(Template::fairing()) .attach(Template::fairing())
.attach(Db::init()) .attach(Db::init())
.manage(args.demo) .manage(args.demo)
.manage(user::SessionStore::new())
.attach(AdHoc::try_on_ignite("DB Setup", move |rocket| async move { .attach(AdHoc::try_on_ignite("DB Setup", move |rocket| async move {
setup_database(args.demo, rocket).await setup_database(args.demo, rocket).await
})) }))

View File

@ -1,13 +1,57 @@
use time::Duration; use time::Duration;
use base64::{engine::general_purpose::STANDARD as BASE64, Engine as _};
use rocket::http::{Cookie, CookieJar, Status}; use rocket::http::{Cookie, CookieJar, Status};
use rocket::serde::{json::Json, Deserialize, Serialize}; use rocket::serde::{json::Json, Deserialize, Serialize};
use rocket::State;
use rocket_db_pools::Connection; use rocket_db_pools::Connection;
use rocket_dyn_templates::{context, Template}; use rocket_dyn_templates::{context, Template};
use std::collections::{HashMap, HashSet};
use std::sync::RwLock;
use uuid::Uuid; use uuid::Uuid;
use crate::Db; use crate::Db;
pub struct SessionStore(RwLock<HashMap<Uuid, HashSet<String>>>);
impl SessionStore {
pub fn new() -> Self {
SessionStore(RwLock::new(HashMap::new()))
}
fn generate_secret() -> String {
let mut bytes = [0u8; 32];
getrandom::getrandom(&mut bytes).expect("Failed to generate random bytes");
BASE64.encode(bytes)
}
fn store(&self, user_id: Uuid, secret: String) {
let mut store = self.0.write().unwrap();
store
.entry(user_id)
.or_insert_with(HashSet::new)
.insert(secret);
}
fn verify(&self, user_id: Uuid, secret: &str) -> bool {
let store = self.0.read().unwrap();
store
.get(&user_id)
.map_or(false, |secrets| secrets.contains(secret))
}
fn remove(&self, user_id: Uuid, secret: &str) {
let mut store = self.0.write().unwrap();
if let Some(secrets) = store.get_mut(&user_id) {
secrets.remove(secret);
// Clean up the user entry if no sessions remain
if secrets.is_empty() {
store.remove(&user_id);
}
}
}
}
#[derive(Debug, Serialize)] #[derive(Debug, Serialize)]
#[serde(crate = "rocket::serde")] #[serde(crate = "rocket::serde")]
pub struct User { pub struct User {
@ -184,6 +228,7 @@ pub async fn login(
mut db: Connection<Db>, mut db: Connection<Db>,
credentials: Json<LoginCredentials>, credentials: Json<LoginCredentials>,
cookies: &CookieJar<'_>, cookies: &CookieJar<'_>,
sessions: &State<SessionStore>,
) -> Result<Json<LoginResponse>, Status> { ) -> Result<Json<LoginResponse>, Status> {
let creds = credentials.into_inner(); let creds = credentials.into_inner();
@ -219,13 +264,15 @@ pub async fn login(
return Err(Status::Unauthorized); return Err(Status::Unauthorized);
} }
// Set session cookie // Generate and store session
// TODO tehre should be a more complicated notion of a session
let user_id = Uuid::parse_str(&user.id).map_err(|_| Status::InternalServerError)?; let user_id = Uuid::parse_str(&user.id).map_err(|_| Status::InternalServerError)?;
let session_secret = SessionStore::generate_secret();
sessions.store(user_id, session_secret.clone());
//TODO make this user-configurable // Set session cookie with both user_id and secret
let max_age = Duration::days(6); let max_age = Duration::days(6);
let mut cookie = Cookie::new("user_id", user_id.to_string()); let cookie_value = format!("{}:{}", user_id, session_secret);
let mut cookie = Cookie::new("user_id", cookie_value);
cookie.set_max_age(max_age); cookie.set_max_age(max_age);
cookies.add_private(cookie); cookies.add_private(cookie);
@ -236,8 +283,15 @@ pub async fn login(
} }
#[post("/logout")] #[post("/logout")]
pub fn logout(cookies: &CookieJar<'_>) -> Status { pub fn logout(cookies: &CookieJar<'_>, sessions: &State<SessionStore>) -> Status {
cookies.remove_private(Cookie::build("user_id")); if let Some(cookie) = cookies.get_private("user_id") {
if let Some((user_id, secret)) = cookie.value().split_once(':') {
if let Ok(user_id) = Uuid::parse_str(user_id) {
sessions.remove(user_id, secret);
}
}
cookies.remove_private(Cookie::build("user_id"));
}
Status::NoContent Status::NoContent
} }
@ -255,13 +309,18 @@ impl<'r> rocket::request::FromRequest<'r> for AuthenticatedUser {
) -> rocket::request::Outcome<Self, Self::Error> { ) -> rocket::request::Outcome<Self, Self::Error> {
use rocket::request::Outcome; use rocket::request::Outcome;
let sessions = request.rocket().state::<SessionStore>().unwrap();
match request.cookies().get_private("user_id") { match request.cookies().get_private("user_id") {
Some(cookie) => { Some(cookie) => {
if let Ok(user_id) = Uuid::parse_str(cookie.value()) { if let Some((user_id, secret)) = cookie.value().split_once(':') {
Outcome::Success(AuthenticatedUser { user_id }) if let Ok(user_id) = Uuid::parse_str(user_id) {
} else { if sessions.verify(user_id, secret) {
Outcome::Forward(Status::Unauthorized) return Outcome::Success(AuthenticatedUser { user_id });
}
}
} }
Outcome::Forward(Status::Unauthorized)
} }
None => Outcome::Forward(Status::Unauthorized), None => Outcome::Forward(Status::Unauthorized),
} }