use axum::extract::ws::Message; use axum::extract::ws::WebSocket; use axum::extract::WebSocketUpgrade; use axum::headers; use axum::http::StatusCode; use axum::response::IntoResponse; use axum::routing::get; use axum::routing::post; use axum::Extension; use axum::Json; use axum::Router; use axum::TypedHeader; use clap::command; use clap::Parser; use common::Achievement; use common::CreateAchievement; use common::CreateMilestone; use common::DeleteAchievement; use common::DeleteMilestone; use common::Milestone; use common::RestResponse; use common::ToggleAchievement; use serde::Deserialize; use serde::Serialize; use std::net::SocketAddr; use std::sync::Arc; use std::time::Duration; use tokio::fs; use tokio_stream::StreamExt; use tower::ServiceBuilder; use tower_http::cors::CorsLayer; use tower_http::trace::DefaultMakeSpan; use tower_http::trace::TraceLayer; const APP_STATE_FILE: &str = "achievements.json"; type SharedState = Arc>; struct SharedStateParts { app_state: AppState, watcher_tx: tokio::sync::watch::Sender, } type Response = Result<(StatusCode, Json>), HandlerError>; // TODO: still needed? #[derive(Debug, thiserror::Error)] enum HandlerError { // #[error("Failed to lock state")] // LockAppStateError, } impl IntoResponse for HandlerError { fn into_response(self) -> axum::response::Response { let error_message = format!("{self}"); (StatusCode::INTERNAL_SERVER_ERROR, error_message).into_response() } } #[derive(clap::Parser)] #[command(author, version, about, long_about = None)] struct Opts { #[clap(short, long, default_value = "127.0.0.1")] address: String, #[clap(short, long, default_value = "4000")] port: u16, } #[tokio::main] async fn main() { tracing_subscriber::fmt::init(); let opts: Opts = Opts::parse(); tracing::info!("Address: {}", opts.address); tracing::info!("Port: {}", opts.port); let socket = format!("{}:{}", opts.address, opts.port); let socket_addr = socket.parse::().unwrap(); let init_app_state = match AppState::read_state().await { Ok(state) => state, Err(AppStateReadError::FileReadError(_)) => { tracing::info!( "Could not load previous state from {}. Creating new default state.", APP_STATE_FILE ); AppState::default() } Err(e) => panic!("Unexpected error: {:?}", e), }; let (app_state_watch_tx, app_state_watch_rx) = tokio::sync::watch::channel(init_app_state.state.clone()); let app_state: SharedState = Arc::new(tokio::sync::RwLock::new(SharedStateParts { app_state: init_app_state, watcher_tx: app_state_watch_tx, })); let (shutdown_tx, mut shutdown_rx) = tokio::sync::watch::channel(false); tokio::spawn(async move { shutdown_signal().await; shutdown_tx.send(true).unwrap(); }); // Start a separate Tokio task for periodically saving the app state let save_task = { let app_state = Arc::clone(&app_state); let mut shutdown_rx = shutdown_rx.clone(); tokio::spawn(async move { loop { tokio::select! { _ = tokio::time::sleep(Duration::from_secs(10 * 60)) => { tracing::debug!("Saving state."); let lock = app_state.write().await; if let Err(err) = lock.app_state.write_state().await { tracing::error!("Failed to write app state: {err}"); } } _ = shutdown_rx.changed() => { tracing::info!("Stopping save task."); break; } } } }) }; let server_task = { let app_state = Arc::clone(&app_state); tokio::spawn(async move { let app = Router::new() .route("/api/v1/create", post(create_achievement)) .route("/api/v1/delete", post(delete_achievement)) .route("/api/v1/toggle", post(toggle_achievement)) .route("/api/v1/create-milestone", post(create_milestone)) .route("/api/v1/delete-milestone", post(delete_milestone)) .route("/api/ws", get(ws_handler)) .layer( ServiceBuilder::new() .layer(Extension(app_state)) .layer(Extension(app_state_watch_rx)), ) .layer( TraceLayer::new_for_http() .make_span_with(DefaultMakeSpan::default().include_headers(true)), ) .layer(CorsLayer::permissive()); tracing::debug!("listening on {}", socket_addr); let server = axum::Server::bind(&socket_addr) .serve(app.into_make_service()) .with_graceful_shutdown(async { shutdown_rx.changed().await.unwrap(); tracing::info!("Starting graceful server shutdown."); }); if let Err(err) = server.await { eprintln!("Server error: {err:?}"); } }) }; // Wait for all tasks to finish. let _ = tokio::join!(server_task, save_task); // Save final app state tracing::info!("Writing app state to disk."); let lock = app_state.write().await; lock.app_state.write_state().await.unwrap(); tracing::info!("Shutdown."); } async fn ws_handler( ws: WebSocketUpgrade, user_agent: Option>, Extension(state_watch_rx): Extension>, ) -> impl IntoResponse { let user_agent = if let Some(TypedHeader(user_agent)) = user_agent { user_agent.to_string() } else { String::from("Unknown browser") }; tracing::debug!("New ws connection from: {user_agent}."); ws.on_upgrade(move |socket| handle_socket(socket, state_watch_rx)) } /// Websocket statemachine (one will be spawned per connection) async fn handle_socket(mut socket: WebSocket, state_watch_rx: tokio::sync::watch::Receiver) { let mut stream = tokio_stream::wrappers::WatchStream::new(state_watch_rx); while let Some(state) = stream.next().await { let state: common::WebSocketMessage = state; let serialized = serde_json::to_string(&state).expect("Failed to serialize app state to JSON"); if socket.send(Message::Text(serialized)).await.is_err() { tracing::debug!("Websocket client disconnected"); break; } } tracing::debug!("Websocket context destroyed"); } async fn create_achievement( Extension(app_state): Extension, Json(create_achievement): Json, ) -> Response<()> { tracing::debug!("Creating achievement: {create_achievement:?}."); let achievement = Achievement { goal: create_achievement.goal, completed: false, uuid: uuid::Uuid::new_v4(), }; let mut lock = app_state.write().await; lock.app_state.state.achievements.push(achievement); lock.watcher_tx .send(lock.app_state.state.clone()) .expect("watch channel is closed, every receiver was dropped."); Ok((StatusCode::CREATED, Json(Ok(())))) } async fn create_milestone( Extension(app_state): Extension, Json(create_milestone): Json, ) -> Response<()> { tracing::debug!("Creating milestone: {create_milestone:?}."); let goal_max = 100; if create_milestone.goal > goal_max { return Ok(( StatusCode::BAD_REQUEST, Json(Err(format!("Max goal allowed: {goal_max}.").into())), )); } let milestone = Milestone { goal: create_milestone.goal, uuid: uuid::Uuid::new_v4(), }; let mut lock = app_state.write().await; lock.app_state.state.milestones.push(milestone); lock.watcher_tx .send(lock.app_state.state.clone()) .expect("watch channel is closed, every receiver was dropped."); Ok((StatusCode::CREATED, Json(Ok(())))) } async fn delete_milestone( Extension(app_state): Extension, Json(delete_milestone): Json, ) -> Response<()> { tracing::debug!("Deleting milestone: {delete_milestone:?}."); let mut lock = app_state.write().await; if let Some(pos) = lock .app_state .state .milestones .iter() .position(|x| x.uuid == delete_milestone.uuid) { lock.app_state.state.milestones.remove(pos); lock.watcher_tx .send(lock.app_state.state.clone()) .expect("watch channel is closed, every receiver was dropped."); } Ok((StatusCode::OK, Json(Ok(())))) } async fn toggle_achievement( Extension(app_state): Extension, Json(toggle_achievement): Json, ) -> Response<()> { tracing::debug!("Toggling achievement: {toggle_achievement:?}."); let mut lock = app_state.write().await; if let Some(achievement) = lock .app_state .state .achievements .iter_mut() .find(|x| x.uuid == toggle_achievement.uuid) { achievement.completed = !achievement.completed; lock.watcher_tx .send(lock.app_state.state.clone()) .expect("watch channel is closed, every receiver was dropped."); } Ok((StatusCode::OK, Json(Ok(())))) } async fn delete_achievement( Extension(app_state): Extension, Json(delete_achievement): Json, ) -> Response<()> { tracing::debug!("Deleting achievement: {delete_achievement:?}."); let mut lock = app_state.write().await; if let Some(pos) = lock .app_state .state .achievements .iter() .position(|x| x.uuid == delete_achievement.uuid) { lock.app_state.state.achievements.remove(pos); lock.watcher_tx .send(lock.app_state.state.clone()) .expect("watch channel is closed, every receiver was dropped."); } Ok((StatusCode::OK, Json(Ok(())))) } async fn shutdown_signal() { let ctrl_c = async { tokio::signal::ctrl_c() .await .expect("failed to install Ctrl+C handler"); }; #[cfg(unix)] let terminate = async { tokio::signal::unix::signal(tokio::signal::unix::SignalKind::terminate()) .expect("failed to install signal handler") .recv() .await; }; // #[cfg(not(unix))] // let terminate = std::future::pending::<()>(); tokio::select! { _ = ctrl_c => {}, _ = terminate => {}, } tracing::info!("Shutdown signal received."); } #[derive(Clone, Debug, Serialize, Deserialize, Default)] struct AppState { state: common::State, } impl AppState { /// Load. async fn read_state() -> Result { let file = fs::read_to_string(APP_STATE_FILE) .await .map_err(AppStateReadError::FileReadError)?; let result = serde_json::from_str(&file)?; Ok(result) } /// Save. async fn write_state(&self) -> Result<(), AppStateWriteError> { let serialized = serde_json::to_string(&self)?; fs::write(APP_STATE_FILE, serialized) .await .map_err(AppStateWriteError::FileWriteError) } } #[derive(Debug, thiserror::Error)] pub enum AppStateReadError { #[error("Failed to read the state file")] FileReadError(std::io::Error), #[error("Failed to deserialize the state")] DeserializationError(#[from] serde_json::Error), } #[derive(Debug, thiserror::Error)] pub enum AppStateWriteError { #[error("Failed to write the state file")] FileWriteError(std::io::Error), #[error("Failed to serialize the state")] SerializationError(#[from] serde_json::Error), }