375 lines
12 KiB
Rust

use axum::extract::ws::Message;
use axum::extract::ws::WebSocket;
use axum::extract::WebSocketUpgrade;
use axum::headers;
use axum::http::StatusCode;
use axum::response::IntoResponse;
use axum::routing::get;
use axum::routing::post;
use axum::Extension;
use axum::Json;
use axum::Router;
use axum::TypedHeader;
use clap::command;
use clap::Parser;
use common::Achievement;
use common::CreateAchievement;
use common::CreateMilestone;
use common::DeleteAchievement;
use common::DeleteMilestone;
use common::Milestone;
use common::RestResponse;
use common::ToggleAchievement;
use serde::Deserialize;
use serde::Serialize;
use std::net::SocketAddr;
use std::sync::Arc;
use std::time::Duration;
use tokio::fs;
use tokio_stream::StreamExt;
use tower::ServiceBuilder;
use tower_http::cors::CorsLayer;
use tower_http::trace::DefaultMakeSpan;
use tower_http::trace::TraceLayer;
/// File the application state is loaded from at startup and saved to periodically and on
/// shutdown. Being a relative path, it is resolved against the process's working directory.
const APP_STATE_FILE: &str = "achievements.json";
/// Mutable server state, bundled with the channel that broadcasts it.
struct SharedStateParts {
    /// The state that is written to disk by `AppState::write_state`.
    app_state: AppState,
    /// Receives a copy of `app_state.state` after every mutation made by the REST handlers;
    /// websocket connections subscribe to it.
    watcher_tx: tokio::sync::watch::Sender<common::State>,
}

/// Lock-protected handle to [`SharedStateParts`], cloned into handlers and background tasks.
type SharedState = Arc<tokio::sync::RwLock<SharedStateParts>>;

/// Return type of the REST handlers: a status code plus a JSON `RestResponse` body.
type Response<T> = Result<(StatusCode, Json<RestResponse<T>>), HandlerError>;
// TODO: still needed?
#[derive(Debug, thiserror::Error)]
enum HandlerError {
// #[error("Failed to lock state")]
// LockAppStateError,
}
impl IntoResponse for HandlerError {
    /// Turns the error into a `500 Internal Server Error` whose plain-text body is the
    /// error's `Display` output.
    fn into_response(self) -> axum::response::Response {
        let body = self.to_string();
        let status = StatusCode::INTERNAL_SERVER_ERROR;
        (status, body).into_response()
    }
}
#[derive(clap::Parser)]
#[command(author, version, about, long_about = None)]
struct Opts {
#[clap(short, long, default_value = "127.0.0.1")]
address: String,
#[clap(short, long, default_value = "4000")]
port: u16,
}
#[tokio::main]
async fn main() {
tracing_subscriber::fmt::init();
let opts: Opts = Opts::parse();
tracing::info!("Address: {}", opts.address);
tracing::info!("Port: {}", opts.port);
let socket = format!("{}:{}", opts.address, opts.port);
let socket_addr = socket.parse::<SocketAddr>().unwrap();
let init_app_state = match AppState::read_state().await {
Ok(state) => state,
Err(AppStateReadError::FileReadError(_)) => {
tracing::info!(
"Could not load previous state from {}. Creating new default state.",
APP_STATE_FILE
);
AppState::default()
}
Err(e) => panic!("Unexpected error: {:?}", e),
};
let (app_state_watch_tx, app_state_watch_rx) = tokio::sync::watch::channel(init_app_state.state.clone());
let app_state: SharedState = Arc::new(tokio::sync::RwLock::new(SharedStateParts {
app_state: init_app_state,
watcher_tx: app_state_watch_tx,
}));
let (shutdown_tx, mut shutdown_rx) = tokio::sync::watch::channel(false);
tokio::spawn(async move {
shutdown_signal().await;
shutdown_tx.send(true).unwrap();
});
// Start a separate Tokio task for periodically saving the app state
let save_task = {
let app_state = Arc::clone(&app_state);
let mut shutdown_rx = shutdown_rx.clone();
tokio::spawn(async move {
loop {
tokio::select! {
_ = tokio::time::sleep(Duration::from_secs(10 * 60)) => {
tracing::debug!("Saving state.");
let lock = app_state.write().await;
if let Err(err) = lock.app_state.write_state().await {
tracing::error!("Failed to write app state: {err}");
}
}
_ = shutdown_rx.changed() => {
tracing::info!("Stopping save task.");
break;
}
}
}
})
};
let server_task = {
let app_state = Arc::clone(&app_state);
tokio::spawn(async move {
let app = Router::new()
.route("/api/v1/create", post(create_achievement))
.route("/api/v1/delete", post(delete_achievement))
.route("/api/v1/toggle", post(toggle_achievement))
.route("/api/v1/create-milestone", post(create_milestone))
.route("/api/v1/delete-milestone", post(delete_milestone))
.route("/api/ws", get(ws_handler))
.layer(
ServiceBuilder::new()
.layer(Extension(app_state))
.layer(Extension(app_state_watch_rx)),
)
.layer(
TraceLayer::new_for_http()
.make_span_with(DefaultMakeSpan::default().include_headers(true)),
)
.layer(CorsLayer::permissive());
tracing::debug!("listening on {}", socket_addr);
let server = axum::Server::bind(&socket_addr)
.serve(app.into_make_service())
.with_graceful_shutdown(async {
shutdown_rx.changed().await.unwrap();
tracing::info!("Starting graceful server shutdown.");
});
if let Err(err) = server.await {
eprintln!("Server error: {err:?}");
}
})
};
// Wait for all tasks to finish.
let _ = tokio::join!(server_task, save_task);
// Save final app state
tracing::info!("Writing app state to disk.");
let lock = app_state.write().await;
lock.app_state.write_state().await.unwrap();
tracing::info!("Shutdown.");
}
/// Handler for the websocket route: logs the client's user agent and upgrades the
/// connection, handing it to [`handle_socket`] together with a state-watch receiver.
async fn ws_handler(
    ws: WebSocketUpgrade,
    user_agent: Option<TypedHeader<headers::UserAgent>>,
    Extension(state_watch_rx): Extension<tokio::sync::watch::Receiver<common::State>>,
) -> impl IntoResponse {
    // Fall back to a placeholder when the client sent no User-Agent header.
    let user_agent = user_agent
        .map(|TypedHeader(agent)| agent.to_string())
        .unwrap_or_else(|| String::from("Unknown browser"));
    tracing::debug!("New ws connection from: {user_agent}.");
    ws.on_upgrade(move |socket| handle_socket(socket, state_watch_rx))
}
/// Websocket statemachine (one will be spawned per connection)
async fn handle_socket(mut socket: WebSocket, state_watch_rx: tokio::sync::watch::Receiver<common::State>) {
let mut stream = tokio_stream::wrappers::WatchStream::new(state_watch_rx);
while let Some(state) = stream.next().await {
let state: common::WebSocketMessage = state;
let serialized = serde_json::to_string(&state).expect("Failed to serialize app state to JSON");
if socket.send(Message::Text(serialized)).await.is_err() {
tracing::debug!("Websocket client disconnected");
break;
}
}
tracing::debug!("Websocket context destroyed");
}
/// `POST /api/v1/create`: appends a new, not-yet-completed achievement with a fresh UUID,
/// broadcasts the updated state and answers `201 Created`.
async fn create_achievement(
    Extension(app_state): Extension<SharedState>,
    Json(create_achievement): Json<CreateAchievement>,
) -> Response<()> {
    tracing::debug!("Creating achievement: {create_achievement:?}.");
    let new_achievement = Achievement {
        uuid: uuid::Uuid::new_v4(),
        completed: false,
        goal: create_achievement.goal,
    };
    let mut guard = app_state.write().await;
    let parts = &mut *guard;
    parts.app_state.state.achievements.push(new_achievement);
    let snapshot = parts.app_state.state.clone();
    parts
        .watcher_tx
        .send(snapshot)
        .expect("watch channel is closed, every receiver was dropped.");
    Ok((StatusCode::CREATED, Json(Ok(()))))
}
/// `POST /api/v1/create-milestone`: adds a milestone with a fresh UUID and broadcasts the
/// updated state (`201 Created`). Goals above 100 are rejected with `400 Bad Request`.
async fn create_milestone(
    Extension(app_state): Extension<SharedState>,
    Json(create_milestone): Json<CreateMilestone>,
) -> Response<()> {
    tracing::debug!("Creating milestone: {create_milestone:?}.");
    // Largest accepted goal; the check runs before the state lock is taken.
    let goal_max = 100;
    if create_milestone.goal > goal_max {
        let rejection = Json(Err(format!("Max goal allowed: {goal_max}.").into()));
        Ok((StatusCode::BAD_REQUEST, rejection))
    } else {
        let new_milestone = Milestone {
            uuid: uuid::Uuid::new_v4(),
            goal: create_milestone.goal,
        };
        let mut guard = app_state.write().await;
        let parts = &mut *guard;
        parts.app_state.state.milestones.push(new_milestone);
        let snapshot = parts.app_state.state.clone();
        parts
            .watcher_tx
            .send(snapshot)
            .expect("watch channel is closed, every receiver was dropped.");
        Ok((StatusCode::CREATED, Json(Ok(()))))
    }
}
/// `POST /api/v1/delete-milestone`: removes the first milestone whose UUID matches and
/// broadcasts the updated state. Answers `200 OK` whether or not a milestone was found;
/// nothing is broadcast when none matched.
async fn delete_milestone(
    Extension(app_state): Extension<SharedState>,
    Json(delete_milestone): Json<DeleteMilestone>,
) -> Response<()> {
    tracing::debug!("Deleting milestone: {delete_milestone:?}.");
    let mut guard = app_state.write().await;
    let parts = &mut *guard;
    let found_at = parts
        .app_state
        .state
        .milestones
        .iter()
        .position(|milestone| milestone.uuid == delete_milestone.uuid);
    if let Some(index) = found_at {
        parts.app_state.state.milestones.remove(index);
        let snapshot = parts.app_state.state.clone();
        parts
            .watcher_tx
            .send(snapshot)
            .expect("watch channel is closed, every receiver was dropped.");
    }
    Ok((StatusCode::OK, Json(Ok(()))))
}
/// `POST /api/v1/toggle`: flips `completed` on the first achievement whose UUID matches and
/// broadcasts the updated state. Answers `200 OK` whether or not an achievement was found;
/// nothing is broadcast when none matched.
async fn toggle_achievement(
    Extension(app_state): Extension<SharedState>,
    Json(toggle_achievement): Json<ToggleAchievement>,
) -> Response<()> {
    tracing::debug!("Toggling achievement: {toggle_achievement:?}.");
    let mut guard = app_state.write().await;
    let parts = &mut *guard;
    let target = parts
        .app_state
        .state
        .achievements
        .iter_mut()
        .find(|candidate| candidate.uuid == toggle_achievement.uuid);
    if let Some(achievement) = target {
        achievement.completed = !achievement.completed;
        let snapshot = parts.app_state.state.clone();
        parts
            .watcher_tx
            .send(snapshot)
            .expect("watch channel is closed, every receiver was dropped.");
    }
    Ok((StatusCode::OK, Json(Ok(()))))
}
/// `POST /api/v1/delete`: removes the first achievement whose UUID matches and broadcasts
/// the updated state. Answers `200 OK` whether or not an achievement was found; nothing is
/// broadcast when none matched.
async fn delete_achievement(
    Extension(app_state): Extension<SharedState>,
    Json(delete_achievement): Json<DeleteAchievement>,
) -> Response<()> {
    tracing::debug!("Deleting achievement: {delete_achievement:?}.");
    let mut guard = app_state.write().await;
    let parts = &mut *guard;
    let found_at = parts
        .app_state
        .state
        .achievements
        .iter()
        .position(|achievement| achievement.uuid == delete_achievement.uuid);
    if let Some(index) = found_at {
        parts.app_state.state.achievements.remove(index);
        let snapshot = parts.app_state.state.clone();
        parts
            .watcher_tx
            .send(snapshot)
            .expect("watch channel is closed, every receiver was dropped.");
    }
    Ok((StatusCode::OK, Json(Ok(()))))
}
/// Completes when the process receives Ctrl+C, or SIGTERM on unix platforms.
///
/// # Panics
/// Panics if a signal handler cannot be installed.
async fn shutdown_signal() {
    let ctrl_c = async {
        tokio::signal::ctrl_c()
            .await
            .expect("failed to install Ctrl+C handler");
    };
    #[cfg(unix)]
    let terminate = async {
        tokio::signal::unix::signal(tokio::signal::unix::SignalKind::terminate())
            .expect("failed to install signal handler")
            .recv()
            .await;
    };
    // Non-unix targets have no SIGTERM; without this fallback `terminate` is undefined there
    // and the crate fails to compile.
    #[cfg(not(unix))]
    let terminate = std::future::pending::<()>();
    tokio::select! {
        _ = ctrl_c => {},
        _ = terminate => {},
    }
    tracing::info!("Shutdown signal received.");
}
/// Everything persisted to [`APP_STATE_FILE`]; serialized as a JSON object with a single
/// `state` field.
#[derive(Clone, Debug, Default, Deserialize, Serialize)]
struct AppState {
    /// The achievements and milestones, which are also broadcast to websocket clients.
    state: common::State,
}
impl AppState {
    /// Loads the state from [`APP_STATE_FILE`].
    ///
    /// # Errors
    /// - `AppStateReadError::FileReadError` if the file cannot be read, including when it
    ///   does not exist.
    /// - `AppStateReadError::DeserializationError` if its contents are not valid JSON for
    ///   this type.
    async fn read_state() -> Result<Self, AppStateReadError> {
        let contents = fs::read_to_string(APP_STATE_FILE)
            .await
            .map_err(AppStateReadError::FileReadError)?;
        Ok(serde_json::from_str(&contents)?)
    }

    /// Saves the state to [`APP_STATE_FILE`] as JSON.
    ///
    /// The JSON is first written to a sibling `.tmp` file, which is then renamed over the
    /// target. A failure part-way through writing therefore leaves the previous save untouched,
    /// instead of a truncated state file that `read_state` can no longer parse.
    ///
    /// # Errors
    /// - `AppStateWriteError::SerializationError` if serialization fails.
    /// - `AppStateWriteError::FileWriteError` if writing or renaming the file fails.
    async fn write_state(&self) -> Result<(), AppStateWriteError> {
        let serialized = serde_json::to_string(self)?;
        let tmp_path = format!("{APP_STATE_FILE}.tmp");
        fs::write(&tmp_path, serialized)
            .await
            .map_err(AppStateWriteError::FileWriteError)?;
        fs::rename(&tmp_path, APP_STATE_FILE)
            .await
            .map_err(AppStateWriteError::FileWriteError)
    }
}
#[derive(Debug, thiserror::Error)]
pub enum AppStateReadError {
#[error("Failed to read the state file")]
FileReadError(std::io::Error),
#[error("Failed to deserialize the state")]
DeserializationError(#[from] serde_json::Error),
}
#[derive(Debug, thiserror::Error)]
pub enum AppStateWriteError {
#[error("Failed to write the state file")]
FileWriteError(std::io::Error),
#[error("Failed to serialize the state")]
SerializationError(#[from] serde_json::Error),
}