377 lines
10 KiB
Rust
377 lines
10 KiB
Rust
use time::Duration;
|
|
|
|
use rocket::http::{Cookie, CookieJar, Status};
|
|
use rocket::serde::{json::Json, Deserialize, Serialize};
|
|
use rocket::State;
|
|
use rocket_db_pools::Connection;
|
|
use rocket_dyn_templates::{context, Template};
|
|
use uuid::Uuid;
|
|
|
|
use crate::session_store::SessionStore;
|
|
use crate::Db;
|
|
|
|
#[derive(Debug, Serialize)]
|
|
#[serde(crate = "rocket::serde")]
|
|
pub struct User {
|
|
pub id: Uuid,
|
|
pub username: String,
|
|
pub password_hash: String,
|
|
pub email: Option<String>,
|
|
pub display_name: Option<String>,
|
|
pub created_at: chrono::DateTime<chrono::Utc>,
|
|
pub admin: bool,
|
|
}
|
|
|
|
impl User {
|
|
pub fn new(
|
|
username: String,
|
|
password_hash: String,
|
|
email: Option<String>,
|
|
display_name: Option<String>,
|
|
) -> Self {
|
|
User {
|
|
id: Uuid::new_v4(),
|
|
username,
|
|
password_hash,
|
|
email,
|
|
display_name,
|
|
created_at: chrono::Utc::now(),
|
|
admin: false,
|
|
}
|
|
}
|
|
|
|
pub async fn write_to_database<'a, E>(&self, executor: E) -> sqlx::Result<()>
|
|
where
|
|
E: sqlx::Executor<'a, Database = sqlx::Sqlite>,
|
|
{
|
|
sqlx::query(
|
|
r#"
|
|
INSERT INTO users (id, username, password_hash, email, display_name, created_at, admin)
|
|
VALUES (?1, ?2, ?3, ?4, ?5, ?6, ?7)
|
|
"#,
|
|
)
|
|
.bind(self.id.to_string())
|
|
.bind(self.username.clone())
|
|
.bind(self.password_hash.clone())
|
|
.bind(self.email.clone())
|
|
.bind(self.display_name.clone())
|
|
.bind(self.created_at.to_rfc3339())
|
|
.bind(self.admin)
|
|
.execute(executor)
|
|
.await?;
|
|
|
|
Ok(())
|
|
}
|
|
}
|
|
|
|
#[derive(Debug, Deserialize)]
|
|
#[serde(crate = "rocket::serde")]
|
|
pub struct NewUser {
|
|
username: String,
|
|
password: String,
|
|
email: Option<String>,
|
|
display_name: Option<String>,
|
|
}
|
|
|
|
#[derive(Debug, Deserialize)]
|
|
#[serde(crate = "rocket::serde")]
|
|
pub struct LoginCredentials {
|
|
username: String,
|
|
password: String,
|
|
}
|
|
|
|
#[derive(Debug, Serialize)]
|
|
#[serde(crate = "rocket::serde")]
|
|
pub struct LoginResponse {
|
|
user_id: Uuid,
|
|
username: String,
|
|
}
|
|
|
|
#[derive(Debug, Serialize)]
|
|
#[serde(crate = "rocket::serde")]
|
|
pub struct SetupError {
|
|
error: String,
|
|
}
|
|
|
|
#[post("/users", data = "<new_user>")]
|
|
pub async fn create_user(
|
|
mut db: Connection<Db>,
|
|
new_user: Json<NewUser>,
|
|
) -> Result<Json<User>, Status> {
|
|
let new_user = new_user.into_inner();
|
|
|
|
// Hash the password - we'll use bcrypt
|
|
let password_hash = bcrypt::hash(new_user.password.as_bytes(), bcrypt::DEFAULT_COST)
|
|
.map_err(|_| Status::InternalServerError)?;
|
|
|
|
let user = User::new(
|
|
new_user.username,
|
|
password_hash,
|
|
new_user.email,
|
|
new_user.display_name,
|
|
);
|
|
|
|
match user.write_to_database(&mut **db).await {
|
|
Ok(_) => Ok(Json(user)),
|
|
Err(e) => {
|
|
eprintln!("Database error: {}", e);
|
|
match e {
|
|
sqlx::Error::Database(db_err) if db_err.is_unique_violation() => {
|
|
Err(Status::Conflict)
|
|
}
|
|
_ => Err(Status::InternalServerError),
|
|
}
|
|
}
|
|
}
|
|
}
|
|
|
|
#[get("/users")]
|
|
pub async fn get_users(mut db: Connection<Db>) -> Result<Json<Vec<User>>, Status> {
|
|
let query = sqlx::query!(
|
|
r#"
|
|
SELECT
|
|
id as "id: String",
|
|
username,
|
|
password_hash,
|
|
email,
|
|
display_name,
|
|
created_at as "created_at: chrono::DateTime<chrono::Utc>",
|
|
admin as "admin: bool"
|
|
FROM users
|
|
"#
|
|
)
|
|
.fetch_all(&mut **db)
|
|
.await
|
|
.map_err(|e| {
|
|
eprintln!("Database error: {}", e);
|
|
Status::InternalServerError
|
|
})?;
|
|
|
|
// Convert the strings to UUIDs
|
|
let users = query
|
|
.into_iter()
|
|
.map(|row| User {
|
|
id: Uuid::parse_str(&row.id).unwrap(),
|
|
username: row.username,
|
|
password_hash: row.password_hash,
|
|
email: row.email,
|
|
display_name: row.display_name,
|
|
created_at: row.created_at,
|
|
admin: row.admin,
|
|
})
|
|
.collect::<Vec<_>>();
|
|
|
|
Ok(Json(users))
|
|
}
|
|
|
|
#[delete("/users/<user_id>")]
|
|
pub async fn delete_user(mut db: Connection<Db>, user_id: &str) -> Status {
|
|
// Validate UUID format
|
|
let uuid = match Uuid::parse_str(user_id) {
|
|
Ok(uuid) => uuid,
|
|
Err(_) => return Status::BadRequest,
|
|
};
|
|
|
|
let query = sqlx::query("DELETE FROM users WHERE id = ?")
|
|
.bind(uuid.to_string())
|
|
.execute(&mut **db)
|
|
.await;
|
|
|
|
match query {
|
|
Ok(result) => {
|
|
if result.rows_affected() > 0 {
|
|
Status::NoContent
|
|
} else {
|
|
Status::NotFound
|
|
}
|
|
}
|
|
Err(e) => {
|
|
eprintln!("Database error: {}", e);
|
|
Status::InternalServerError
|
|
}
|
|
}
|
|
}
|
|
|
|
#[post("/login", data = "<credentials>")]
|
|
pub async fn login(
|
|
mut db: Connection<Db>,
|
|
credentials: Json<LoginCredentials>,
|
|
cookies: &CookieJar<'_>,
|
|
sessions: &State<SessionStore>,
|
|
) -> Result<Json<LoginResponse>, Status> {
|
|
let creds = credentials.into_inner();
|
|
|
|
// Find user by username
|
|
let user = sqlx::query!(
|
|
r#"
|
|
SELECT
|
|
id as "id: String",
|
|
username,
|
|
password_hash
|
|
FROM users
|
|
WHERE username = ?
|
|
"#,
|
|
creds.username
|
|
)
|
|
.fetch_optional(&mut **db)
|
|
.await
|
|
.map_err(|e| {
|
|
eprintln!("Database error: {}", e);
|
|
Status::InternalServerError
|
|
})?;
|
|
|
|
let user = match user {
|
|
Some(user) => user,
|
|
None => return Err(Status::Unauthorized),
|
|
};
|
|
|
|
// Verify password
|
|
let valid = bcrypt::verify(creds.password.as_bytes(), &user.password_hash)
|
|
.map_err(|_| Status::InternalServerError)?;
|
|
|
|
if !valid {
|
|
return Err(Status::Unauthorized);
|
|
}
|
|
|
|
// Generate and store session
|
|
let user_id = Uuid::parse_str(&user.id).map_err(|_| Status::InternalServerError)?;
|
|
let session_secret = SessionStore::generate_secret();
|
|
sessions.store(user_id, session_secret.clone());
|
|
|
|
// Set session cookie with both user_id and secret
|
|
let max_age = Duration::days(6);
|
|
let cookie_value = format!("{}:{}", user_id, session_secret);
|
|
let mut cookie = Cookie::new("user_id", cookie_value);
|
|
cookie.set_max_age(max_age);
|
|
cookies.add_private(cookie);
|
|
|
|
Ok(Json(LoginResponse {
|
|
user_id,
|
|
username: user.username,
|
|
}))
|
|
}
|
|
|
|
#[post("/logout")]
|
|
pub fn logout(cookies: &CookieJar<'_>, sessions: &State<SessionStore>) -> Status {
|
|
if let Some(cookie) = cookies.get_private("user_id") {
|
|
if let Some((user_id, secret)) = cookie.value().split_once(':') {
|
|
if let Ok(user_id) = Uuid::parse_str(user_id) {
|
|
sessions.remove(user_id, secret);
|
|
}
|
|
}
|
|
cookies.remove_private(Cookie::build("user_id"));
|
|
}
|
|
Status::NoContent
|
|
}
|
|
|
|
// Add auth guard
|
|
pub struct AuthenticatedUser {
|
|
pub user_id: Uuid,
|
|
}
|
|
|
|
#[rocket::async_trait]
|
|
impl<'r> rocket::request::FromRequest<'r> for AuthenticatedUser {
|
|
type Error = ();
|
|
|
|
async fn from_request(
|
|
request: &'r rocket::Request<'_>,
|
|
) -> rocket::request::Outcome<Self, Self::Error> {
|
|
use rocket::request::Outcome;
|
|
|
|
let sessions = request.rocket().state::<SessionStore>().unwrap();
|
|
|
|
match request.cookies().get_private("user_id") {
|
|
Some(cookie) => {
|
|
if let Some((user_id, secret)) = cookie.value().split_once(':') {
|
|
if let Ok(user_id) = Uuid::parse_str(user_id) {
|
|
if sessions.verify(user_id, secret) {
|
|
return Outcome::Success(AuthenticatedUser { user_id });
|
|
}
|
|
}
|
|
}
|
|
Outcome::Forward(Status::Unauthorized)
|
|
}
|
|
None => Outcome::Forward(Status::Unauthorized),
|
|
}
|
|
}
|
|
}
|
|
|
|
#[get("/setup")]
|
|
pub async fn setup_page(mut db: Connection<Db>) -> Result<Template, Status> {
|
|
// Check if any users exist
|
|
let count = sqlx::query!("SELECT COUNT(*) as count FROM users")
|
|
.fetch_one(&mut **db)
|
|
.await
|
|
.map_err(|_| Status::InternalServerError)?
|
|
.count;
|
|
|
|
if count > 0 {
|
|
// If users exist, redirect to login
|
|
Err(Status::SeeOther)
|
|
} else {
|
|
// Show setup page
|
|
Ok(Template::render("setup", context! {}))
|
|
}
|
|
}
|
|
|
|
#[post("/setup", data = "<new_user>")]
|
|
pub async fn setup(
|
|
mut db: Connection<Db>,
|
|
new_user: Json<NewUser>,
|
|
) -> Result<Status, Json<SetupError>> {
|
|
let new_user = new_user.into_inner();
|
|
|
|
// Check if any users exist
|
|
let count = sqlx::query!("SELECT COUNT(*) as count FROM users")
|
|
.fetch_one(&mut **db)
|
|
.await
|
|
.map_err(|e| {
|
|
eprintln!("Database error: {}", e);
|
|
Json(SetupError {
|
|
error: "Internal server error".to_string(),
|
|
})
|
|
})?
|
|
.count;
|
|
|
|
if count > 0 {
|
|
return Err(Json(SetupError {
|
|
error: "Setup has already been completed".to_string(),
|
|
}));
|
|
}
|
|
|
|
let password = new_user.password.as_bytes();
|
|
|
|
// Hash the password
|
|
let password_hash = bcrypt::hash(password, bcrypt::DEFAULT_COST).map_err(|e| {
|
|
eprintln!("Password hashing error: {}", e);
|
|
Json(SetupError {
|
|
error: "Failed to process password".to_string(),
|
|
})
|
|
})?;
|
|
|
|
let mut user = User::new(
|
|
new_user.username,
|
|
password_hash,
|
|
new_user.email,
|
|
new_user.display_name,
|
|
);
|
|
user.admin = true; // This is an admin user
|
|
|
|
match user.write_to_database(&mut **db).await {
|
|
Ok(_) => Ok(Status::Created),
|
|
Err(e) => {
|
|
eprintln!("Database error: {}", e);
|
|
match e {
|
|
sqlx::Error::Database(db_err) if db_err.is_unique_violation() => {
|
|
Err(Json(SetupError {
|
|
error: "Username already exists".to_string(),
|
|
}))
|
|
}
|
|
_ => Err(Json(SetupError {
|
|
error: "Failed to create user".to_string(),
|
|
})),
|
|
}
|
|
}
|
|
}
|
|
}
|