375 lines
12 KiB
Rust
375 lines
12 KiB
Rust
use axum::extract::ws::Message;
|
|
use axum::extract::ws::WebSocket;
|
|
use axum::extract::WebSocketUpgrade;
|
|
use axum::headers;
|
|
use axum::http::StatusCode;
|
|
use axum::response::IntoResponse;
|
|
use axum::routing::get;
|
|
use axum::routing::post;
|
|
use axum::Extension;
|
|
use axum::Json;
|
|
use axum::Router;
|
|
use axum::TypedHeader;
|
|
use clap::command;
|
|
use clap::Parser;
|
|
use common::Achievement;
|
|
use common::CreateAchievement;
|
|
use common::CreateMilestone;
|
|
use common::DeleteAchievement;
|
|
use common::DeleteMilestone;
|
|
use common::Milestone;
|
|
use common::RestResponse;
|
|
use common::ToggleAchievement;
|
|
use serde::Deserialize;
|
|
use serde::Serialize;
|
|
use std::net::SocketAddr;
|
|
use std::sync::Arc;
|
|
use std::time::Duration;
|
|
use tokio::fs;
|
|
use tokio_stream::StreamExt;
|
|
use tower::ServiceBuilder;
|
|
use tower_http::cors::CorsLayer;
|
|
use tower_http::trace::DefaultMakeSpan;
|
|
use tower_http::trace::TraceLayer;
|
|
|
|
const APP_STATE_FILE: &str = "achievements.json";
|
|
|
|
type SharedState = Arc<tokio::sync::RwLock<SharedStateParts>>;
|
|
struct SharedStateParts {
|
|
app_state: AppState,
|
|
watcher_tx: tokio::sync::watch::Sender<common::State>,
|
|
}
|
|
|
|
type Response<T> = Result<(StatusCode, Json<RestResponse<T>>), HandlerError>;
|
|
// TODO: still needed?
|
|
#[derive(Debug, thiserror::Error)]
|
|
enum HandlerError {
|
|
// #[error("Failed to lock state")]
|
|
// LockAppStateError,
|
|
}
|
|
impl IntoResponse for HandlerError {
|
|
fn into_response(self) -> axum::response::Response {
|
|
let error_message = format!("{self}");
|
|
(StatusCode::INTERNAL_SERVER_ERROR, error_message).into_response()
|
|
}
|
|
}
|
|
|
|
#[derive(clap::Parser)]
|
|
#[command(author, version, about, long_about = None)]
|
|
struct Opts {
|
|
#[clap(short, long, default_value = "127.0.0.1")]
|
|
address: String,
|
|
|
|
#[clap(short, long, default_value = "4000")]
|
|
port: u16,
|
|
}
|
|
|
|
#[tokio::main]
|
|
async fn main() {
|
|
tracing_subscriber::fmt::init();
|
|
|
|
let opts: Opts = Opts::parse();
|
|
|
|
tracing::info!("Address: {}", opts.address);
|
|
tracing::info!("Port: {}", opts.port);
|
|
|
|
let socket = format!("{}:{}", opts.address, opts.port);
|
|
let socket_addr = socket.parse::<SocketAddr>().unwrap();
|
|
|
|
let init_app_state = match AppState::read_state().await {
|
|
Ok(state) => state,
|
|
Err(AppStateReadError::FileReadError(_)) => {
|
|
tracing::info!(
|
|
"Could not load previous state from {}. Creating new default state.",
|
|
APP_STATE_FILE
|
|
);
|
|
AppState::default()
|
|
}
|
|
Err(e) => panic!("Unexpected error: {:?}", e),
|
|
};
|
|
let (app_state_watch_tx, app_state_watch_rx) = tokio::sync::watch::channel(init_app_state.state.clone());
|
|
let app_state: SharedState = Arc::new(tokio::sync::RwLock::new(SharedStateParts {
|
|
app_state: init_app_state,
|
|
watcher_tx: app_state_watch_tx,
|
|
}));
|
|
|
|
let (shutdown_tx, mut shutdown_rx) = tokio::sync::watch::channel(false);
|
|
tokio::spawn(async move {
|
|
shutdown_signal().await;
|
|
shutdown_tx.send(true).unwrap();
|
|
});
|
|
|
|
// Start a separate Tokio task for periodically saving the app state
|
|
let save_task = {
|
|
let app_state = Arc::clone(&app_state);
|
|
let mut shutdown_rx = shutdown_rx.clone();
|
|
tokio::spawn(async move {
|
|
loop {
|
|
tokio::select! {
|
|
_ = tokio::time::sleep(Duration::from_secs(10 * 60)) => {
|
|
tracing::debug!("Saving state.");
|
|
let lock = app_state.write().await;
|
|
if let Err(err) = lock.app_state.write_state().await {
|
|
tracing::error!("Failed to write app state: {err}");
|
|
}
|
|
}
|
|
_ = shutdown_rx.changed() => {
|
|
tracing::info!("Stopping save task.");
|
|
break;
|
|
}
|
|
}
|
|
}
|
|
})
|
|
};
|
|
|
|
let server_task = {
|
|
let app_state = Arc::clone(&app_state);
|
|
tokio::spawn(async move {
|
|
let app = Router::new()
|
|
.route("/api/v1/create", post(create_achievement))
|
|
.route("/api/v1/delete", post(delete_achievement))
|
|
.route("/api/v1/toggle", post(toggle_achievement))
|
|
.route("/api/v1/create-milestone", post(create_milestone))
|
|
.route("/api/v1/delete-milestone", post(delete_milestone))
|
|
.route("/api/ws", get(ws_handler))
|
|
.layer(
|
|
ServiceBuilder::new()
|
|
.layer(Extension(app_state))
|
|
.layer(Extension(app_state_watch_rx)),
|
|
)
|
|
.layer(
|
|
TraceLayer::new_for_http()
|
|
.make_span_with(DefaultMakeSpan::default().include_headers(true)),
|
|
)
|
|
.layer(CorsLayer::permissive());
|
|
|
|
tracing::debug!("listening on {}", socket_addr);
|
|
let server = axum::Server::bind(&socket_addr)
|
|
.serve(app.into_make_service())
|
|
.with_graceful_shutdown(async {
|
|
shutdown_rx.changed().await.unwrap();
|
|
tracing::info!("Starting graceful server shutdown.");
|
|
});
|
|
|
|
if let Err(err) = server.await {
|
|
eprintln!("Server error: {err:?}");
|
|
}
|
|
})
|
|
};
|
|
|
|
// Wait for all tasks to finish.
|
|
let _ = tokio::join!(server_task, save_task);
|
|
|
|
// Save final app state
|
|
tracing::info!("Writing app state to disk.");
|
|
let lock = app_state.write().await;
|
|
lock.app_state.write_state().await.unwrap();
|
|
|
|
tracing::info!("Shutdown.");
|
|
}
|
|
|
|
async fn ws_handler(
|
|
ws: WebSocketUpgrade,
|
|
user_agent: Option<TypedHeader<headers::UserAgent>>,
|
|
Extension(state_watch_rx): Extension<tokio::sync::watch::Receiver<common::State>>,
|
|
) -> impl IntoResponse {
|
|
let user_agent = if let Some(TypedHeader(user_agent)) = user_agent {
|
|
user_agent.to_string()
|
|
} else {
|
|
String::from("Unknown browser")
|
|
};
|
|
tracing::debug!("New ws connection from: {user_agent}.");
|
|
|
|
ws.on_upgrade(move |socket| handle_socket(socket, state_watch_rx))
|
|
}
|
|
|
|
/// Websocket statemachine (one will be spawned per connection)
|
|
async fn handle_socket(mut socket: WebSocket, state_watch_rx: tokio::sync::watch::Receiver<common::State>) {
|
|
let mut stream = tokio_stream::wrappers::WatchStream::new(state_watch_rx);
|
|
while let Some(state) = stream.next().await {
|
|
let state: common::WebSocketMessage = state;
|
|
let serialized = serde_json::to_string(&state).expect("Failed to serialize app state to JSON");
|
|
if socket.send(Message::Text(serialized)).await.is_err() {
|
|
tracing::debug!("Websocket client disconnected");
|
|
break;
|
|
}
|
|
}
|
|
|
|
tracing::debug!("Websocket context destroyed");
|
|
}
|
|
|
|
async fn create_achievement(
|
|
Extension(app_state): Extension<SharedState>,
|
|
Json(create_achievement): Json<CreateAchievement>,
|
|
) -> Response<()> {
|
|
tracing::debug!("Creating achievement: {create_achievement:?}.");
|
|
let achievement = Achievement {
|
|
goal: create_achievement.goal,
|
|
completed: false,
|
|
uuid: uuid::Uuid::new_v4(),
|
|
};
|
|
let mut lock = app_state.write().await;
|
|
lock.app_state.state.achievements.push(achievement);
|
|
lock.watcher_tx
|
|
.send(lock.app_state.state.clone())
|
|
.expect("watch channel is closed, every receiver was dropped.");
|
|
Ok((StatusCode::CREATED, Json(Ok(()))))
|
|
}
|
|
|
|
async fn create_milestone(
|
|
Extension(app_state): Extension<SharedState>,
|
|
Json(create_milestone): Json<CreateMilestone>,
|
|
) -> Response<()> {
|
|
tracing::debug!("Creating milestone: {create_milestone:?}.");
|
|
|
|
let goal_max = 100;
|
|
if create_milestone.goal > goal_max {
|
|
return Ok((
|
|
StatusCode::BAD_REQUEST,
|
|
Json(Err(format!("Max goal allowed: {goal_max}.").into())),
|
|
));
|
|
}
|
|
|
|
let milestone = Milestone {
|
|
goal: create_milestone.goal,
|
|
uuid: uuid::Uuid::new_v4(),
|
|
};
|
|
let mut lock = app_state.write().await;
|
|
lock.app_state.state.milestones.push(milestone);
|
|
lock.watcher_tx
|
|
.send(lock.app_state.state.clone())
|
|
.expect("watch channel is closed, every receiver was dropped.");
|
|
Ok((StatusCode::CREATED, Json(Ok(()))))
|
|
}
|
|
|
|
async fn delete_milestone(
|
|
Extension(app_state): Extension<SharedState>,
|
|
Json(delete_milestone): Json<DeleteMilestone>,
|
|
) -> Response<()> {
|
|
tracing::debug!("Deleting milestone: {delete_milestone:?}.");
|
|
let mut lock = app_state.write().await;
|
|
if let Some(pos) = lock
|
|
.app_state
|
|
.state
|
|
.milestones
|
|
.iter()
|
|
.position(|x| x.uuid == delete_milestone.uuid)
|
|
{
|
|
lock.app_state.state.milestones.remove(pos);
|
|
lock.watcher_tx
|
|
.send(lock.app_state.state.clone())
|
|
.expect("watch channel is closed, every receiver was dropped.");
|
|
}
|
|
Ok((StatusCode::OK, Json(Ok(()))))
|
|
}
|
|
|
|
async fn toggle_achievement(
|
|
Extension(app_state): Extension<SharedState>,
|
|
Json(toggle_achievement): Json<ToggleAchievement>,
|
|
) -> Response<()> {
|
|
tracing::debug!("Toggling achievement: {toggle_achievement:?}.");
|
|
let mut lock = app_state.write().await;
|
|
if let Some(achievement) = lock
|
|
.app_state
|
|
.state
|
|
.achievements
|
|
.iter_mut()
|
|
.find(|x| x.uuid == toggle_achievement.uuid)
|
|
{
|
|
achievement.completed = !achievement.completed;
|
|
lock.watcher_tx
|
|
.send(lock.app_state.state.clone())
|
|
.expect("watch channel is closed, every receiver was dropped.");
|
|
}
|
|
Ok((StatusCode::OK, Json(Ok(()))))
|
|
}
|
|
|
|
async fn delete_achievement(
|
|
Extension(app_state): Extension<SharedState>,
|
|
Json(delete_achievement): Json<DeleteAchievement>,
|
|
) -> Response<()> {
|
|
tracing::debug!("Deleting achievement: {delete_achievement:?}.");
|
|
let mut lock = app_state.write().await;
|
|
if let Some(pos) = lock
|
|
.app_state
|
|
.state
|
|
.achievements
|
|
.iter()
|
|
.position(|x| x.uuid == delete_achievement.uuid)
|
|
{
|
|
lock.app_state.state.achievements.remove(pos);
|
|
lock.watcher_tx
|
|
.send(lock.app_state.state.clone())
|
|
.expect("watch channel is closed, every receiver was dropped.");
|
|
}
|
|
Ok((StatusCode::OK, Json(Ok(()))))
|
|
}
|
|
|
|
async fn shutdown_signal() {
|
|
let ctrl_c = async {
|
|
tokio::signal::ctrl_c()
|
|
.await
|
|
.expect("failed to install Ctrl+C handler");
|
|
};
|
|
|
|
#[cfg(unix)]
|
|
let terminate = async {
|
|
tokio::signal::unix::signal(tokio::signal::unix::SignalKind::terminate())
|
|
.expect("failed to install signal handler")
|
|
.recv()
|
|
.await;
|
|
};
|
|
|
|
// #[cfg(not(unix))]
|
|
// let terminate = std::future::pending::<()>();
|
|
|
|
tokio::select! {
|
|
_ = ctrl_c => {},
|
|
_ = terminate => {},
|
|
}
|
|
|
|
tracing::info!("Shutdown signal received.");
|
|
}
|
|
|
|
#[derive(Clone, Debug, Serialize, Deserialize, Default)]
|
|
struct AppState {
|
|
state: common::State,
|
|
}
|
|
|
|
impl AppState {
|
|
/// Load.
|
|
async fn read_state() -> Result<Self, AppStateReadError> {
|
|
let file = fs::read_to_string(APP_STATE_FILE)
|
|
.await
|
|
.map_err(AppStateReadError::FileReadError)?;
|
|
let result = serde_json::from_str(&file)?;
|
|
Ok(result)
|
|
}
|
|
|
|
/// Save.
|
|
async fn write_state(&self) -> Result<(), AppStateWriteError> {
|
|
let serialized = serde_json::to_string(&self)?;
|
|
fs::write(APP_STATE_FILE, serialized)
|
|
.await
|
|
.map_err(AppStateWriteError::FileWriteError)
|
|
}
|
|
}
|
|
|
|
#[derive(Debug, thiserror::Error)]
|
|
pub enum AppStateReadError {
|
|
#[error("Failed to read the state file")]
|
|
FileReadError(std::io::Error),
|
|
|
|
#[error("Failed to deserialize the state")]
|
|
DeserializationError(#[from] serde_json::Error),
|
|
}
|
|
|
|
#[derive(Debug, thiserror::Error)]
|
|
pub enum AppStateWriteError {
|
|
#[error("Failed to write the state file")]
|
|
FileWriteError(std::io::Error),
|
|
|
|
#[error("Failed to serialize the state")]
|
|
SerializationError(#[from] serde_json::Error),
|
|
}
|