Commit Diff


commit - c7a127938cc306c8499617dab086101579e2db8f
commit + 7583df586ed40b9202e7585fa5fd795ab2e3315b
blob - 3fdf1a4225655dff12de93a12c6b4d04506fead0
blob + 21363f3dede514363d423a5be1011fbf372824a2
--- web/README.md
+++ web/README.md
@@ -11,5 +11,5 @@
 * [x] Tracing
 * [x] Messages (like flask)
 * [x] Sessions
-* [ ] CSRF
+* [x] CSRF
 * [ ] 404
blob - 007e47f70907b68b74487d36b6e17685fe1d7edf
blob + bfb4153d79077a8e152f7d143adf4d4002ed7a39
--- web/template/Cargo.toml
+++ web/template/Cargo.toml
@@ -8,8 +8,9 @@ edition = "2024"
 
 [dependencies]
 anyhow = "=1.0.100"
-axum = "=0.8.6"
+axum = { version = "=0.8.6", features = ["macros"] }
 axum-messages = "0.8.0"
+axum_csrf = { version = "0.11.0", features = ["layer"] }
 metrics = { version = "=0.24.2", default-features = false }
 metrics-exporter-prometheus = { version = "=0.17.2", default-features = false }
 minijinja = "=2.12.0"
blob - ba99019fcec86494e72815521c26c3c910d9ddf0
blob + 315a365408046be29f9ce0bf5eac5637a77bb219
--- web/template/src/main.rs
+++ web/template/src/main.rs
@@ -40,6 +40,7 @@ async fn start_main_server() -> anyhow::Result<()> {
     env.add_template("home", include_str!("../templates/home.jinja"))?;
     env.add_template("content", include_str!("../templates/content.jinja"))?;
     env.add_template("about", include_str!("../templates/about.jinja"))?;
+    env.add_template("csrf", include_str!("../templates/csrf.jinja"))?;
 
     let app_state = Arc::new(state::AppState { env });
 
blob - aaa8b25965e97f36a3a3fd85cba20832b7651d65
blob + 170890ae66c2802d5e4c9638ba179ea083e46c7b
--- web/template/src/router.rs
+++ web/template/src/router.rs
@@ -16,13 +16,14 @@
 use std::sync::Arc;
 
 use axum::{
-    Router,
+    Form, Router,
     extract::State,
     http::{HeaderName, Request, StatusCode},
     middleware,
     response::{Html, IntoResponse, Redirect},
     routing::get,
 };
+use axum_csrf::{CsrfConfig, CsrfLayer, CsrfToken, Key};
 use axum_messages::{Messages, MessagesManagerLayer};
 use minijinja::context;
 use serde::{Deserialize, Serialize};
@@ -47,10 +48,19 @@ const REQUEST_ID_HEADER: &str = "x-request-id";
 #[derive(Default, Deserialize, Serialize)]
 struct Counter(usize);
 
+#[derive(Serialize, Deserialize)]
+struct Keys {
+    authenticity_token: String, // matches the hidden input name in templates/csrf.jinja
+}
+
 pub(crate) fn route(app_state: Arc<AppState>) -> Router {
     let x_request_id = HeaderName::from_static(REQUEST_ID_HEADER);
 
     let session_store = MemoryStore::default();
+    let cookie_key = Key::generate();
+    let config = CsrfConfig::default()
+        .with_key(Some(cookie_key))
+        .with_cookie_domain(Some("127.0.0.1"));
 
     Router::new()
         .route("/", get(handler_home))
@@ -59,6 +69,7 @@ pub(crate) fn route(app_state: Arc<AppState>) -> Route
         .route("/session", get(handler_session))
         .route("/message", get(set_messages_handler))
         .route("/read-messages", get(read_messages_handler))
+        .route("/csrf", get(csrf_root).post(csrf_check_key))
         .layer(MessagesManagerLayer)
         // TODO(msi): from config folder asssets
         .nest_service("/assets", ServeDir::new("assets"))
@@ -85,6 +96,7 @@ pub(crate) fn route(app_state: Arc<AppState>) -> Route
                 .with_secure(false)
                 .with_expiry(Expiry::OnInactivity(Duration::seconds(10))),
             MessagesManagerLayer,
+            CsrfLayer::new(config),
             // TODO(msi): from config
             TimeoutLayer::new(std::time::Duration::from_secs(10)),
             PropagateRequestIdLayer::new(x_request_id),
@@ -94,6 +106,34 @@ pub(crate) fn route(app_state: Arc<AppState>) -> Route
         .with_state(app_state)
 }
 
+/// Renders the `/csrf` demo form with the request's CSRF authenticity token embedded in it.
+async fn csrf_root(token: CsrfToken, State(state): State<Arc<AppState>>) -> impl IntoResponse {
+    // Token generation or template rendering failures become a 500 instead of a panic.
+    let Ok(authenticity_token) = token.authenticity_token() else {
+        return StatusCode::INTERNAL_SERVER_ERROR.into_response();
+    };
+    let rendered = state.env.get_template("csrf").and_then(|template| {
+        template.render(context! { title => "Csrf", authenticity_token => authenticity_token })
+    });
+    match rendered {
+        // Returning the token lets its IntoResponse impl add it to the response cookies.
+        Ok(html) => (token, Html(html)).into_response(),
+        Err(_) => StatusCode::INTERNAL_SERVER_ERROR.into_response(),
+    }
+}
+
+/// Handles the POST from the `/csrf` form and checks its `authenticity_token` field.
+///
+/// Responds 200 when the token verifies and 403 Forbidden otherwise.
+async fn csrf_check_key(token: CsrfToken, Form(payload): Form<Keys>) -> impl IntoResponse {
+    // Verify the submitted form field against the request's CsrfToken.
+    match token.verify(&payload.authenticity_token) {
+        Ok(_) => (StatusCode::OK, "Token is Valid lets do stuff!"),
+        // A forged or stale submission must not be reported with a success status.
+        Err(_) => (StatusCode::FORBIDDEN, "Token is invalid"),
+    }
+}
+
 async fn set_messages_handler(messages: Messages) -> impl IntoResponse {
     messages.info("Hello, world!").debug("This is a debug message.");
 
blob - /dev/null
blob + 2a9ab2e3e79834a1b4bb83ec45ca782e6b797f6f (mode 644)
--- /dev/null
+++ web/template/templates/csrf.jinja
@@ -0,0 +1,10 @@
+{% extends "layout" %}
+{% block title %}{{ super() }} | {{ title }} {% endblock %}
+{% block body %}
+<h1>{{ title }}</h1>
+<p>Submit the form to check that its hidden CSRF token is accepted by the server.</p>
+<form method="post" action="/csrf">
+    <input type="hidden" name="authenticity_token" value="{{ authenticity_token }}"/>
+    <input id="button" type="submit" value="Submit" />
+</form>
+{% endblock %}
blob - 946dcdc44e8a01f3853417895e92d21f0c374b05
blob + f162c30c552c43988a707bd39c1cbd88cc1ed4a4
--- web/template/templates/layout.jinja
+++ web/template/templates/layout.jinja
@@ -11,6 +11,7 @@
             <li><a href="/session">Session</a></li>
             <li><a href="/message">Set Message</a></li>
             <li><a href="/read-messages">Read Messages</a></li>
+            <li><a href="/csrf">Csrf</a></li>
         </ul>
     </nav>
     <h1><h1>Hello, World web =]</h1>