feat: initial REST endpoints and websocket

/create
/delete
/ws
This commit is contained in:
2023-06-10 18:54:23 +02:00
parent 5501b63379
commit 6e845907d1
9 changed files with 1476 additions and 3 deletions

View File

@@ -5,3 +5,15 @@ authors.workspace = true
edition.workspace = true
[dependencies]
serde.workspace = true
axum = { version = "0.6", features = ["ws", "headers"] }
serde_json = "1"
tokio = { version = "1", features = ["full"] }
tracing = "0.1"
tracing-subscriber = "0.3"
common.workspace = true
thiserror = "1.0.40"
tower = "0.4.13"
uuid.workspace = true
tower-http = { version = "0.4.0", features = ["fs", "trace"] }
tokio-stream = { version = "0.1.14", features = ["sync"] }

View File

@@ -1 +1,275 @@
fn main() {}
use axum::extract::ws::Message;
use axum::extract::ws::WebSocket;
use axum::extract::WebSocketUpgrade;
use axum::headers;
use axum::http::StatusCode;
use axum::response::IntoResponse;
use axum::routing::get;
use axum::routing::post;
use axum::Extension;
use axum::Json;
use axum::Router;
use axum::TypedHeader;
use common::Achievement;
use common::CreateAchievement;
use common::DeleteAchievement;
use serde::Deserialize;
use serde::Serialize;
use std::net::SocketAddr;
use std::sync::Arc;
use std::time::Duration;
use tokio::fs;
use tokio_stream::StreamExt;
use tower::ServiceBuilder;
use tower_http::trace::DefaultMakeSpan;
use tower_http::trace::TraceLayer;
const APP_STATE_FILE: &str = "achievements.json";
type SharedState = Arc<tokio::sync::RwLock<SharedStateParts>>;
struct SharedStateParts {
app_state: AppState,
watcher_tx: tokio::sync::watch::Sender<AppState>,
}
#[tokio::main]
async fn main() {
tracing_subscriber::fmt::init();
let init_app_state = match AppState::read_state().await {
Ok(state) => state,
Err(AppStateReadError::FileReadError(_)) => {
tracing::info!(
"Could not load previous state from {}. Creating new default state.",
APP_STATE_FILE
);
AppState::default()
}
Err(e) => panic!("Unexpected error: {:?}", e),
};
let (app_state_watch_tx, app_state_watch_rx) = tokio::sync::watch::channel(init_app_state.clone());
let app_state: SharedState = Arc::new(tokio::sync::RwLock::new(SharedStateParts {
app_state: init_app_state,
watcher_tx: app_state_watch_tx,
}));
let (shutdown_tx, mut shutdown_rx) = tokio::sync::watch::channel(false);
tokio::spawn(async move {
shutdown_signal().await;
shutdown_tx.send(true).unwrap();
});
// Start a separate Tokio task for periodically saving the app state
let save_task = {
let app_state = Arc::clone(&app_state);
let mut shutdown_rx = shutdown_rx.clone();
tokio::spawn(async move {
loop {
tokio::select! {
_ = tokio::time::sleep(Duration::from_secs(10 * 60)) => {
tracing::debug!("Saving state.");
let lock = app_state.write().await;
if let Err(err) = lock.app_state.write_state().await {
tracing::error!("Failed to write app state: {err}");
}
}
_ = shutdown_rx.changed() => {
tracing::info!("Stopping save task.");
break;
}
}
}
})
};
let server_task = {
let app_state = Arc::clone(&app_state);
tokio::spawn(async move {
let app = Router::new()
.route("/create", post(create_achievement))
.route("/delete", post(delete_achievement))
.route("/ws", get(ws_handler))
.layer(
ServiceBuilder::new()
.layer(Extension(app_state))
.layer(Extension(app_state_watch_rx)),
)
.layer(
TraceLayer::new_for_http()
.make_span_with(DefaultMakeSpan::default().include_headers(true)),
);
let addr = SocketAddr::from(([127, 0, 0, 1], 4000));
tracing::debug!("listening on {}", addr);
let server = axum::Server::bind(&addr)
.serve(app.into_make_service())
.with_graceful_shutdown(async {
shutdown_rx.changed().await.unwrap();
tracing::info!("Starting graceful server shutdown.");
});
if let Err(err) = server.await {
eprintln!("Server error: {err:?}");
}
})
};
// Wait for all tasks to finish.
let _ = tokio::join!(server_task, save_task);
// Save final app state
tracing::info!("Writing app state to disk.");
let lock = app_state.write().await;
lock.app_state.write_state().await.unwrap();
tracing::info!("Shutdown.");
}
async fn ws_handler(
ws: WebSocketUpgrade,
user_agent: Option<TypedHeader<headers::UserAgent>>,
Extension(app_state_watch_rx): Extension<tokio::sync::watch::Receiver<AppState>>,
) -> impl IntoResponse {
let user_agent = if let Some(TypedHeader(user_agent)) = user_agent {
user_agent.to_string()
} else {
String::from("Unknown browser")
};
tracing::debug!("{user_agent} connected websocket.");
ws.on_upgrade(move |socket| handle_socket(socket, app_state_watch_rx))
}
/// Websocket statemachine (one will be spawned per connection)
async fn handle_socket(mut socket: WebSocket, app_state_watch_rx: tokio::sync::watch::Receiver<AppState>) {
let mut stream = tokio_stream::wrappers::WatchStream::new(app_state_watch_rx);
loop {
let app_state = stream.next().await;
let serialized = serde_json::to_string(&app_state).expect("Failed to serialize app state to JSON");
if socket.send(Message::Text(serialized)).await.is_err() {
tracing::debug!("Websocket client disconnected");
break;
}
}
tracing::debug!("Websocket context destroyed");
}
async fn create_achievement(
Extension(app_state): Extension<SharedState>,
Json(create_achievement): Json<CreateAchievement>,
) -> Result<(StatusCode, ()), HandlerError> {
let achievement = Achievement {
goal: create_achievement.goal,
completed: false,
uuid: uuid::Uuid::new_v4(),
};
let mut lock = app_state.write().await;
lock.app_state.achievements.push(achievement);
lock.watcher_tx
.send(lock.app_state.clone())
.expect("watch channel is closed, every receiver was dropped.");
Ok((StatusCode::CREATED, ()))
}
async fn delete_achievement(
Extension(app_state): Extension<SharedState>,
Json(delete_achievement): Json<DeleteAchievement>,
) -> Result<(StatusCode, ()), HandlerError> {
let mut lock = app_state.write().await;
if let Some(pos) = lock
.app_state
.achievements
.iter()
.position(|x| x.uuid == delete_achievement.uuid)
{
lock.app_state.achievements.remove(pos);
lock.watcher_tx
.send(lock.app_state.clone())
.expect("watch channel is closed, every receiver was dropped.");
}
Ok((StatusCode::OK, ()))
}
async fn shutdown_signal() {
let ctrl_c = async {
tokio::signal::ctrl_c()
.await
.expect("failed to install Ctrl+C handler");
};
#[cfg(unix)]
let terminate = async {
tokio::signal::unix::signal(tokio::signal::unix::SignalKind::terminate())
.expect("failed to install signal handler")
.recv()
.await;
};
// #[cfg(not(unix))]
// let terminate = std::future::pending::<()>();
tokio::select! {
_ = ctrl_c => {},
_ = terminate => {},
}
tracing::info!("Shutdown signal received.");
}
#[derive(Clone, Debug, Serialize, Deserialize, Default)]
struct AppState {
achievements: Vec<Achievement>,
}
impl AppState {
/// Load.
async fn read_state() -> Result<Self, AppStateReadError> {
let file = fs::read_to_string(APP_STATE_FILE)
.await
.map_err(AppStateReadError::FileReadError)?;
let result = serde_json::from_str(&file)?;
Ok(result)
}
/// Save.
async fn write_state(&self) -> Result<(), AppStateWriteError> {
let serialized = serde_json::to_string(&self)?;
fs::write(APP_STATE_FILE, serialized)
.await
.map_err(AppStateWriteError::FileWriteError)
}
}
#[derive(Debug, thiserror::Error)]
pub enum AppStateReadError {
#[error("Failed to read the state file")]
FileReadError(std::io::Error),
#[error("Failed to deserialize the state")]
DeserializationError(#[from] serde_json::Error),
}
#[derive(Debug, thiserror::Error)]
pub enum AppStateWriteError {
#[error("Failed to write the state file")]
FileWriteError(std::io::Error),
#[error("Failed to serialize the state")]
SerializationError(#[from] serde_json::Error),
}
// TODO: still needed?
#[derive(Debug, thiserror::Error)]
enum HandlerError {
// #[error("Failed to lock state")]
// LockAppStateError,
}
impl IntoResponse for HandlerError {
fn into_response(self) -> axum::response::Response {
let error_message = format!("{self}");
(StatusCode::INTERNAL_SERVER_ERROR, error_message).into_response()
}
}

View File

@@ -5,3 +5,5 @@ authors.workspace = true
edition.workspace = true
[dependencies]
serde.workspace = true
uuid.workspace = true

View File

@@ -1 +1,19 @@
use serde::Deserialize;
use serde::Serialize;
#[derive(Clone, Debug, Serialize, Deserialize)]
pub struct Achievement {
pub goal: String,
pub completed: bool,
pub uuid: uuid::Uuid,
}
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct CreateAchievement {
pub goal: String,
}
#[derive(Debug, Serialize, Clone, Deserialize)]
pub struct DeleteAchievement {
pub uuid: uuid::Uuid,
}