multi sessions
This commit is contained in:
parent
5f3a9473ff
commit
c74ebbe5fa
2
Cargo.lock
generated
2
Cargo.lock
generated
@ -2663,10 +2663,12 @@ version = "0.1.0"
|
||||
dependencies = [
|
||||
"argon2",
|
||||
"atom_syndication",
|
||||
"base64 0.21.7",
|
||||
"bcrypt",
|
||||
"chrono",
|
||||
"clap",
|
||||
"feed-rs",
|
||||
"getrandom 0.2.15",
|
||||
"reqwest",
|
||||
"rocket",
|
||||
"rocket_db_pools",
|
||||
|
@ -20,3 +20,5 @@ feed-rs = "2.3.1"
|
||||
reqwest = { version = "0.12.12", features = ["json"] }
|
||||
tokio = "1.43.0"
|
||||
time = "0.3.37"
|
||||
base64 = "0.21"
|
||||
getrandom = "0.2"
|
||||
|
13
src/demo.rs
13
src/demo.rs
@ -1,6 +1,6 @@
|
||||
use chrono;
|
||||
use sqlx;
|
||||
use rocket::serde;
|
||||
use sqlx;
|
||||
use uuid::Uuid;
|
||||
|
||||
use crate::feeds::Feed;
|
||||
@ -77,16 +77,17 @@ pub async fn setup_demo_data(pool: &sqlx::SqlitePool) {
|
||||
);
|
||||
acx.categorization = vec!["Substack".to_string()];
|
||||
|
||||
|
||||
let feeds = [bbc_news, xkcd, isidore, acx];
|
||||
|
||||
for feed in feeds {
|
||||
// TODO: This insert logic is substantially the same as Feed::write_to_database.
|
||||
// Should find a way to unify these two code paths to avoid duplication.
|
||||
let categorization_json = serde::json::to_value(feed.categorization).map_err(|e| {
|
||||
eprintln!("Failed to serialize categorization: {}", e);
|
||||
sqlx::Error::Decode(Box::new(e))
|
||||
}).unwrap();
|
||||
let categorization_json = serde::json::to_value(feed.categorization)
|
||||
.map_err(|e| {
|
||||
eprintln!("Failed to serialize categorization: {}", e);
|
||||
sqlx::Error::Decode(Box::new(e))
|
||||
})
|
||||
.unwrap();
|
||||
println!("{}", categorization_json);
|
||||
|
||||
sqlx::query(
|
||||
|
@ -101,7 +101,10 @@ fn rocket() -> _ {
|
||||
|
||||
let figment = rocket::Config::figment()
|
||||
.merge(("databases.rss_data.url", db_url))
|
||||
.merge(("secret_key", std::env::var("SECRET_KEY").expect("SECRET_KEY environment variable must be set")));
|
||||
.merge((
|
||||
"secret_key",
|
||||
std::env::var("SECRET_KEY").expect("SECRET_KEY environment variable must be set"),
|
||||
));
|
||||
|
||||
rocket::custom(figment)
|
||||
.mount(
|
||||
@ -128,6 +131,7 @@ fn rocket() -> _ {
|
||||
.attach(Template::fairing())
|
||||
.attach(Db::init())
|
||||
.manage(args.demo)
|
||||
.manage(user::SessionStore::new())
|
||||
.attach(AdHoc::try_on_ignite("DB Setup", move |rocket| async move {
|
||||
setup_database(args.demo, rocket).await
|
||||
}))
|
||||
|
79
src/user.rs
79
src/user.rs
@ -1,13 +1,57 @@
|
||||
use time::Duration;
|
||||
|
||||
use base64::{engine::general_purpose::STANDARD as BASE64, Engine as _};
|
||||
use rocket::http::{Cookie, CookieJar, Status};
|
||||
use rocket::serde::{json::Json, Deserialize, Serialize};
|
||||
use rocket::State;
|
||||
use rocket_db_pools::Connection;
|
||||
use rocket_dyn_templates::{context, Template};
|
||||
use std::collections::{HashMap, HashSet};
|
||||
use std::sync::RwLock;
|
||||
use uuid::Uuid;
|
||||
|
||||
use crate::Db;
|
||||
|
||||
pub struct SessionStore(RwLock<HashMap<Uuid, HashSet<String>>>);
|
||||
|
||||
impl SessionStore {
|
||||
pub fn new() -> Self {
|
||||
SessionStore(RwLock::new(HashMap::new()))
|
||||
}
|
||||
|
||||
fn generate_secret() -> String {
|
||||
let mut bytes = [0u8; 32];
|
||||
getrandom::getrandom(&mut bytes).expect("Failed to generate random bytes");
|
||||
BASE64.encode(bytes)
|
||||
}
|
||||
|
||||
fn store(&self, user_id: Uuid, secret: String) {
|
||||
let mut store = self.0.write().unwrap();
|
||||
store
|
||||
.entry(user_id)
|
||||
.or_insert_with(HashSet::new)
|
||||
.insert(secret);
|
||||
}
|
||||
|
||||
fn verify(&self, user_id: Uuid, secret: &str) -> bool {
|
||||
let store = self.0.read().unwrap();
|
||||
store
|
||||
.get(&user_id)
|
||||
.map_or(false, |secrets| secrets.contains(secret))
|
||||
}
|
||||
|
||||
fn remove(&self, user_id: Uuid, secret: &str) {
|
||||
let mut store = self.0.write().unwrap();
|
||||
if let Some(secrets) = store.get_mut(&user_id) {
|
||||
secrets.remove(secret);
|
||||
// Clean up the user entry if no sessions remain
|
||||
if secrets.is_empty() {
|
||||
store.remove(&user_id);
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
#[derive(Debug, Serialize)]
|
||||
#[serde(crate = "rocket::serde")]
|
||||
pub struct User {
|
||||
@ -184,6 +228,7 @@ pub async fn login(
|
||||
mut db: Connection<Db>,
|
||||
credentials: Json<LoginCredentials>,
|
||||
cookies: &CookieJar<'_>,
|
||||
sessions: &State<SessionStore>,
|
||||
) -> Result<Json<LoginResponse>, Status> {
|
||||
let creds = credentials.into_inner();
|
||||
|
||||
@ -219,13 +264,15 @@ pub async fn login(
|
||||
return Err(Status::Unauthorized);
|
||||
}
|
||||
|
||||
// Set session cookie
|
||||
// TODO tehre should be a more complicated notion of a session
|
||||
// Generate and store session
|
||||
let user_id = Uuid::parse_str(&user.id).map_err(|_| Status::InternalServerError)?;
|
||||
let session_secret = SessionStore::generate_secret();
|
||||
sessions.store(user_id, session_secret.clone());
|
||||
|
||||
//TODO make this user-configurable
|
||||
// Set session cookie with both user_id and secret
|
||||
let max_age = Duration::days(6);
|
||||
let mut cookie = Cookie::new("user_id", user_id.to_string());
|
||||
let cookie_value = format!("{}:{}", user_id, session_secret);
|
||||
let mut cookie = Cookie::new("user_id", cookie_value);
|
||||
cookie.set_max_age(max_age);
|
||||
cookies.add_private(cookie);
|
||||
|
||||
@ -236,8 +283,15 @@ pub async fn login(
|
||||
}
|
||||
|
||||
#[post("/logout")]
|
||||
pub fn logout(cookies: &CookieJar<'_>) -> Status {
|
||||
cookies.remove_private(Cookie::build("user_id"));
|
||||
pub fn logout(cookies: &CookieJar<'_>, sessions: &State<SessionStore>) -> Status {
|
||||
if let Some(cookie) = cookies.get_private("user_id") {
|
||||
if let Some((user_id, secret)) = cookie.value().split_once(':') {
|
||||
if let Ok(user_id) = Uuid::parse_str(user_id) {
|
||||
sessions.remove(user_id, secret);
|
||||
}
|
||||
}
|
||||
cookies.remove_private(Cookie::build("user_id"));
|
||||
}
|
||||
Status::NoContent
|
||||
}
|
||||
|
||||
@ -255,13 +309,18 @@ impl<'r> rocket::request::FromRequest<'r> for AuthenticatedUser {
|
||||
) -> rocket::request::Outcome<Self, Self::Error> {
|
||||
use rocket::request::Outcome;
|
||||
|
||||
let sessions = request.rocket().state::<SessionStore>().unwrap();
|
||||
|
||||
match request.cookies().get_private("user_id") {
|
||||
Some(cookie) => {
|
||||
if let Ok(user_id) = Uuid::parse_str(cookie.value()) {
|
||||
Outcome::Success(AuthenticatedUser { user_id })
|
||||
} else {
|
||||
Outcome::Forward(Status::Unauthorized)
|
||||
if let Some((user_id, secret)) = cookie.value().split_once(':') {
|
||||
if let Ok(user_id) = Uuid::parse_str(user_id) {
|
||||
if sessions.verify(user_id, secret) {
|
||||
return Outcome::Success(AuthenticatedUser { user_id });
|
||||
}
|
||||
}
|
||||
}
|
||||
Outcome::Forward(Status::Unauthorized)
|
||||
}
|
||||
None => Outcome::Forward(Status::Unauthorized),
|
||||
}
|
||||
|
Loading…
Reference in New Issue
Block a user