diff --git a/crates/backend/src/admin.rs b/crates/backend/src/admin.rs index 944d8dd..443ce65 100644 --- a/crates/backend/src/admin.rs +++ b/crates/backend/src/admin.rs @@ -28,8 +28,7 @@ async fn auth_middleware( req: Request, next: Next, ) -> Response { - let admin_tok = headers.get("x-admin-tok"); - if headers.get("x-admin-tok") == Some(&HeaderValue::from_static(ADMIN_TOK)) { + if state.check_admin_tok(&headers) { let res = next.run(req).await; return res; } diff --git a/crates/backend/src/chat.rs b/crates/backend/src/chat.rs index a1e791b..349b7ac 100644 --- a/crates/backend/src/chat.rs +++ b/crates/backend/src/chat.rs @@ -114,25 +114,21 @@ pub async fn post( return StatusCode::BAD_REQUEST.into_response(); } - if headers.get("x-admin-tok") == Some(&HeaderValue::from_static(ADMIN_TOK)) { - sqlx::query!( - r#"insert into messages (chat_id, content, from_admin) values ($1, $2, true);"#, - chat.id, - body - ) - .execute(state.pool()) - .await - .unwrap(); + let is_admin = state.check_admin_tok(&headers); + + sqlx::query!( + r#"insert into messages (chat_id, content, from_admin) values ($1, $2, $3);"#, + chat.id, + body, + is_admin + ) + .execute(state.pool()) + .await + .unwrap(); + + if is_admin { StatusCode::OK.into_response() } else { - sqlx::query!( - r#"insert into messages (chat_id, content, from_admin) values ($1, $2, false);"#, - chat.id, - body - ) - .execute(state.pool()) - .await - .unwrap(); Redirect::to(&format!("/{url_path}")).into_response() } } diff --git a/crates/backend/src/state.rs b/crates/backend/src/state.rs index 31f7f51..8bdb3c3 100644 --- a/crates/backend/src/state.rs +++ b/crates/backend/src/state.rs @@ -1,5 +1,8 @@ +use std::sync::Arc; + use axum::response::IntoResponse; -use http::StatusCode; +use config::Config; +use http::{HeaderMap, HeaderValue, StatusCode}; use sqlx::{Pool, Postgres}; use thiserror::Error; @@ -9,23 +12,85 @@ type Result = std::result::Result; const DB_URL: &str = "postgres://chatdings@localhost/chatdings"; #[derive(Debug, Clone)] -pub struct AppState { +pub struct AppState(Arc); + +#[derive(Debug)] +pub struct AppStateInner { pool: Pool, + config: Config, +} + +mod config { + use std::{env, fs}; + + #[derive(Debug)] + pub(super) struct Config { + db_url: String, + admin_token: String, + } + + impl Config { + pub fn read_from_env() -> Self { + let db_url = env::var("CHATDINGS_DB_URL") + .expect("environment variable CHATDINGS_DB_URL has to exist"); + let admin_token_path = env::var("CHATDINGS_ADMIN_TOKEN_PATH") + .expect("environment variable CHATDINGS_ADMIN_TOKEN has to exist"); + let admin_token = fs::read_to_string(&admin_token_path) + .expect(&format!("failed to read '{admin_token_path:?}'")) + .trim() + .to_string(); + + Self { + db_url, + admin_token, + } + } + + pub fn db_url(&self) -> &str { + &self.db_url + } + + pub fn admin_tok(&self) -> &str { + &self.admin_token + } + } } impl AppState { pub async fn init() -> Result { - let pool = Pool::::connect(DB_URL).await?; + Ok(Self(Arc::new(AppStateInner::init().await?))) + } + pub async fn fetch_chat_by_url_path(&self, url_path: &str) -> Result { + self.0.fetch_chat_by_url_path(url_path).await + } + pub async fn fetch_messages(&self, chat: &Chat) -> Result> { + self.0.fetch_messages(chat).await + } + pub async fn send_message(&self, chat: &Chat, content: String, from_admin: bool) -> Result<()> { + self.0.send_message(chat, content, from_admin).await + } + pub fn check_admin_tok(&self, headers: &HeaderMap) -> bool { + self.0.check_admin_tok(headers) + } + pub fn pool(&self) -> &Pool { + self.0.pool() + } +} + +impl AppStateInner { + async fn init() -> Result { + let config = Config::read_from_env(); + let pool = Pool::::connect(config.db_url()).await?; sqlx::migrate!() .run(&pool) .await .expect("migration should not fail"); - Ok(Self { pool }) + Ok(Self { pool, config }) } - pub async fn fetch_chat_by_url_path(&self, url_path: &str) -> Result { + async fn fetch_chat_by_url_path(&self, url_path: &str) -> Result { Ok( sqlx::query_as!(Chat, r#"select * from chats where url_path = $1"#, url_path) .fetch_one(&self.pool) @@ -33,7 +98,7 @@ impl AppState { ) } - pub async fn fetch_messages(&self, chat: &Chat) -> Result> { + async fn fetch_messages(&self, chat: &Chat) -> Result> { Ok(sqlx::query_as!( Message, r#"select * from messages where chat_id = $1"#, @@ -43,11 +108,15 @@ impl AppState { .await?) } - pub async fn send_message(&self, chat: &Chat, content: String, from_admin: bool) -> Result<()> { + async fn send_message(&self, chat: &Chat, content: String, from_admin: bool) -> Result<()> { todo!() } - pub fn pool(&self) -> &Pool { + fn check_admin_tok(&self, headers: &HeaderMap) -> bool { + headers.get("x-admin-tok") == Some(&HeaderValue::from_str(self.config.admin_tok()).unwrap()) + } + + fn pool(&self) -> &Pool { &self.pool } } diff --git a/flake.nix b/flake.nix index 822a22a..a1180c0 100644 --- a/flake.nix +++ b/flake.nix @@ -59,8 +59,10 @@ listen_addresses = "127.0.0.1"; }; - env = { + env = rec { DATABASE_URL = "postgres://localhost/chatdings"; + CHATDINGS_DB_URL = DATABASE_URL; + CHATDINGS_ADMIN_TOKEN_PATH = "./test_admin_tok.txt"; }; }) ]; diff --git a/test_admin_tok.txt b/test_admin_tok.txt new file mode 100644 index 0000000..48d07ba --- /dev/null +++ b/test_admin_tok.txt @@ -0,0 +1 @@ +miauu