multi sessions
This commit is contained in:
parent
5f3a9473ff
commit
c74ebbe5fa
2
Cargo.lock
generated
2
Cargo.lock
generated
@ -2663,10 +2663,12 @@ version = "0.1.0"
|
|||||||
dependencies = [
|
dependencies = [
|
||||||
"argon2",
|
"argon2",
|
||||||
"atom_syndication",
|
"atom_syndication",
|
||||||
|
"base64 0.21.7",
|
||||||
"bcrypt",
|
"bcrypt",
|
||||||
"chrono",
|
"chrono",
|
||||||
"clap",
|
"clap",
|
||||||
"feed-rs",
|
"feed-rs",
|
||||||
|
"getrandom 0.2.15",
|
||||||
"reqwest",
|
"reqwest",
|
||||||
"rocket",
|
"rocket",
|
||||||
"rocket_db_pools",
|
"rocket_db_pools",
|
||||||
|
@ -20,3 +20,5 @@ feed-rs = "2.3.1"
|
|||||||
reqwest = { version = "0.12.12", features = ["json"] }
|
reqwest = { version = "0.12.12", features = ["json"] }
|
||||||
tokio = "1.43.0"
|
tokio = "1.43.0"
|
||||||
time = "0.3.37"
|
time = "0.3.37"
|
||||||
|
base64 = "0.21"
|
||||||
|
getrandom = "0.2"
|
||||||
|
@ -1,6 +1,6 @@
|
|||||||
use chrono;
|
use chrono;
|
||||||
use sqlx;
|
|
||||||
use rocket::serde;
|
use rocket::serde;
|
||||||
|
use sqlx;
|
||||||
use uuid::Uuid;
|
use uuid::Uuid;
|
||||||
|
|
||||||
use crate::feeds::Feed;
|
use crate::feeds::Feed;
|
||||||
@ -77,16 +77,17 @@ pub async fn setup_demo_data(pool: &sqlx::SqlitePool) {
|
|||||||
);
|
);
|
||||||
acx.categorization = vec!["Substack".to_string()];
|
acx.categorization = vec!["Substack".to_string()];
|
||||||
|
|
||||||
|
|
||||||
let feeds = [bbc_news, xkcd, isidore, acx];
|
let feeds = [bbc_news, xkcd, isidore, acx];
|
||||||
|
|
||||||
for feed in feeds {
|
for feed in feeds {
|
||||||
// TODO: This insert logic is substantially the same as Feed::write_to_database.
|
// TODO: This insert logic is substantially the same as Feed::write_to_database.
|
||||||
// Should find a way to unify these two code paths to avoid duplication.
|
// Should find a way to unify these two code paths to avoid duplication.
|
||||||
let categorization_json = serde::json::to_value(feed.categorization).map_err(|e| {
|
let categorization_json = serde::json::to_value(feed.categorization)
|
||||||
|
.map_err(|e| {
|
||||||
eprintln!("Failed to serialize categorization: {}", e);
|
eprintln!("Failed to serialize categorization: {}", e);
|
||||||
sqlx::Error::Decode(Box::new(e))
|
sqlx::Error::Decode(Box::new(e))
|
||||||
}).unwrap();
|
})
|
||||||
|
.unwrap();
|
||||||
println!("{}", categorization_json);
|
println!("{}", categorization_json);
|
||||||
|
|
||||||
sqlx::query(
|
sqlx::query(
|
||||||
|
@ -101,7 +101,10 @@ fn rocket() -> _ {
|
|||||||
|
|
||||||
let figment = rocket::Config::figment()
|
let figment = rocket::Config::figment()
|
||||||
.merge(("databases.rss_data.url", db_url))
|
.merge(("databases.rss_data.url", db_url))
|
||||||
.merge(("secret_key", std::env::var("SECRET_KEY").expect("SECRET_KEY environment variable must be set")));
|
.merge((
|
||||||
|
"secret_key",
|
||||||
|
std::env::var("SECRET_KEY").expect("SECRET_KEY environment variable must be set"),
|
||||||
|
));
|
||||||
|
|
||||||
rocket::custom(figment)
|
rocket::custom(figment)
|
||||||
.mount(
|
.mount(
|
||||||
@ -128,6 +131,7 @@ fn rocket() -> _ {
|
|||||||
.attach(Template::fairing())
|
.attach(Template::fairing())
|
||||||
.attach(Db::init())
|
.attach(Db::init())
|
||||||
.manage(args.demo)
|
.manage(args.demo)
|
||||||
|
.manage(user::SessionStore::new())
|
||||||
.attach(AdHoc::try_on_ignite("DB Setup", move |rocket| async move {
|
.attach(AdHoc::try_on_ignite("DB Setup", move |rocket| async move {
|
||||||
setup_database(args.demo, rocket).await
|
setup_database(args.demo, rocket).await
|
||||||
}))
|
}))
|
||||||
|
77
src/user.rs
77
src/user.rs
@ -1,13 +1,57 @@
|
|||||||
use time::Duration;
|
use time::Duration;
|
||||||
|
|
||||||
|
use base64::{engine::general_purpose::STANDARD as BASE64, Engine as _};
|
||||||
use rocket::http::{Cookie, CookieJar, Status};
|
use rocket::http::{Cookie, CookieJar, Status};
|
||||||
use rocket::serde::{json::Json, Deserialize, Serialize};
|
use rocket::serde::{json::Json, Deserialize, Serialize};
|
||||||
|
use rocket::State;
|
||||||
use rocket_db_pools::Connection;
|
use rocket_db_pools::Connection;
|
||||||
use rocket_dyn_templates::{context, Template};
|
use rocket_dyn_templates::{context, Template};
|
||||||
|
use std::collections::{HashMap, HashSet};
|
||||||
|
use std::sync::RwLock;
|
||||||
use uuid::Uuid;
|
use uuid::Uuid;
|
||||||
|
|
||||||
use crate::Db;
|
use crate::Db;
|
||||||
|
|
||||||
|
pub struct SessionStore(RwLock<HashMap<Uuid, HashSet<String>>>);
|
||||||
|
|
||||||
|
impl SessionStore {
|
||||||
|
pub fn new() -> Self {
|
||||||
|
SessionStore(RwLock::new(HashMap::new()))
|
||||||
|
}
|
||||||
|
|
||||||
|
fn generate_secret() -> String {
|
||||||
|
let mut bytes = [0u8; 32];
|
||||||
|
getrandom::getrandom(&mut bytes).expect("Failed to generate random bytes");
|
||||||
|
BASE64.encode(bytes)
|
||||||
|
}
|
||||||
|
|
||||||
|
fn store(&self, user_id: Uuid, secret: String) {
|
||||||
|
let mut store = self.0.write().unwrap();
|
||||||
|
store
|
||||||
|
.entry(user_id)
|
||||||
|
.or_insert_with(HashSet::new)
|
||||||
|
.insert(secret);
|
||||||
|
}
|
||||||
|
|
||||||
|
fn verify(&self, user_id: Uuid, secret: &str) -> bool {
|
||||||
|
let store = self.0.read().unwrap();
|
||||||
|
store
|
||||||
|
.get(&user_id)
|
||||||
|
.map_or(false, |secrets| secrets.contains(secret))
|
||||||
|
}
|
||||||
|
|
||||||
|
fn remove(&self, user_id: Uuid, secret: &str) {
|
||||||
|
let mut store = self.0.write().unwrap();
|
||||||
|
if let Some(secrets) = store.get_mut(&user_id) {
|
||||||
|
secrets.remove(secret);
|
||||||
|
// Clean up the user entry if no sessions remain
|
||||||
|
if secrets.is_empty() {
|
||||||
|
store.remove(&user_id);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
#[derive(Debug, Serialize)]
|
#[derive(Debug, Serialize)]
|
||||||
#[serde(crate = "rocket::serde")]
|
#[serde(crate = "rocket::serde")]
|
||||||
pub struct User {
|
pub struct User {
|
||||||
@ -184,6 +228,7 @@ pub async fn login(
|
|||||||
mut db: Connection<Db>,
|
mut db: Connection<Db>,
|
||||||
credentials: Json<LoginCredentials>,
|
credentials: Json<LoginCredentials>,
|
||||||
cookies: &CookieJar<'_>,
|
cookies: &CookieJar<'_>,
|
||||||
|
sessions: &State<SessionStore>,
|
||||||
) -> Result<Json<LoginResponse>, Status> {
|
) -> Result<Json<LoginResponse>, Status> {
|
||||||
let creds = credentials.into_inner();
|
let creds = credentials.into_inner();
|
||||||
|
|
||||||
@ -219,13 +264,15 @@ pub async fn login(
|
|||||||
return Err(Status::Unauthorized);
|
return Err(Status::Unauthorized);
|
||||||
}
|
}
|
||||||
|
|
||||||
// Set session cookie
|
// Generate and store session
|
||||||
// TODO tehre should be a more complicated notion of a session
|
|
||||||
let user_id = Uuid::parse_str(&user.id).map_err(|_| Status::InternalServerError)?;
|
let user_id = Uuid::parse_str(&user.id).map_err(|_| Status::InternalServerError)?;
|
||||||
|
let session_secret = SessionStore::generate_secret();
|
||||||
|
sessions.store(user_id, session_secret.clone());
|
||||||
|
|
||||||
//TODO make this user-configurable
|
// Set session cookie with both user_id and secret
|
||||||
let max_age = Duration::days(6);
|
let max_age = Duration::days(6);
|
||||||
let mut cookie = Cookie::new("user_id", user_id.to_string());
|
let cookie_value = format!("{}:{}", user_id, session_secret);
|
||||||
|
let mut cookie = Cookie::new("user_id", cookie_value);
|
||||||
cookie.set_max_age(max_age);
|
cookie.set_max_age(max_age);
|
||||||
cookies.add_private(cookie);
|
cookies.add_private(cookie);
|
||||||
|
|
||||||
@ -236,8 +283,15 @@ pub async fn login(
|
|||||||
}
|
}
|
||||||
|
|
||||||
#[post("/logout")]
|
#[post("/logout")]
|
||||||
pub fn logout(cookies: &CookieJar<'_>) -> Status {
|
pub fn logout(cookies: &CookieJar<'_>, sessions: &State<SessionStore>) -> Status {
|
||||||
|
if let Some(cookie) = cookies.get_private("user_id") {
|
||||||
|
if let Some((user_id, secret)) = cookie.value().split_once(':') {
|
||||||
|
if let Ok(user_id) = Uuid::parse_str(user_id) {
|
||||||
|
sessions.remove(user_id, secret);
|
||||||
|
}
|
||||||
|
}
|
||||||
cookies.remove_private(Cookie::build("user_id"));
|
cookies.remove_private(Cookie::build("user_id"));
|
||||||
|
}
|
||||||
Status::NoContent
|
Status::NoContent
|
||||||
}
|
}
|
||||||
|
|
||||||
@ -255,14 +309,19 @@ impl<'r> rocket::request::FromRequest<'r> for AuthenticatedUser {
|
|||||||
) -> rocket::request::Outcome<Self, Self::Error> {
|
) -> rocket::request::Outcome<Self, Self::Error> {
|
||||||
use rocket::request::Outcome;
|
use rocket::request::Outcome;
|
||||||
|
|
||||||
|
let sessions = request.rocket().state::<SessionStore>().unwrap();
|
||||||
|
|
||||||
match request.cookies().get_private("user_id") {
|
match request.cookies().get_private("user_id") {
|
||||||
Some(cookie) => {
|
Some(cookie) => {
|
||||||
if let Ok(user_id) = Uuid::parse_str(cookie.value()) {
|
if let Some((user_id, secret)) = cookie.value().split_once(':') {
|
||||||
Outcome::Success(AuthenticatedUser { user_id })
|
if let Ok(user_id) = Uuid::parse_str(user_id) {
|
||||||
} else {
|
if sessions.verify(user_id, secret) {
|
||||||
Outcome::Forward(Status::Unauthorized)
|
return Outcome::Success(AuthenticatedUser { user_id });
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
}
|
||||||
|
Outcome::Forward(Status::Unauthorized)
|
||||||
|
}
|
||||||
None => Outcome::Forward(Status::Unauthorized),
|
None => Outcome::Forward(Status::Unauthorized),
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
Loading…
Reference in New Issue
Block a user