BrainMinder/cmd/web/middleware.go

139 lines
3.2 KiB
Go
Raw Permalink Normal View History

2024-08-22 10:13:16 +02:00
package main
import (
"fmt"
"log/slog"
"net/http"
"brainminder.speedtech.it/internal/response"
"github.com/justinas/nosurf"
"github.com/tomasen/realip"
)
// recoverPanic is middleware that recovers a panic raised further down the
// handler chain and turns it into a server error response via
// app.serverError, instead of letting net/http log it and drop the request.
//
// A panic with http.ErrAbortHandler is re-raised unchanged: that value is the
// net/http sentinel for deliberately aborting a response and the server
// handles it silently.
func (app *application) recoverPanic(next http.Handler) http.Handler {
	return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
		defer func() {
			rec := recover()
			if rec == nil {
				return
			}
			if rec == http.ErrAbortHandler {
				panic(rec)
			}

			// Keep the original error value when the panic carried one, so
			// it is not flattened to a string; otherwise format it with %v,
			// which (unlike %s) renders any type cleanly.
			err, ok := rec.(error)
			if !ok {
				err = fmt.Errorf("%v", rec)
			}

			// The connection's state after a panic is unknown; ask net/http
			// to close it once this response has been written.
			w.Header().Set("Connection", "close")
			app.serverError(w, r, err)
		}()
		next.ServeHTTP(w, r)
	})
}
// securityHeaders is middleware that sets a fixed group of security-related
// response headers on every response and then delegates to next.
func (app *application) securityHeaders(next http.Handler) http.Handler {
	fn := func(w http.ResponseWriter, r *http.Request) {
		h := w.Header()
		h.Set("Referrer-Policy", "origin-when-cross-origin")
		h.Set("X-Content-Type-Options", "nosniff")
		h.Set("X-Frame-Options", "deny")
		// Allows a service worker to be registered with "/" as its scope.
		h.Set("Service-Worker-Allowed", "/")
		next.ServeHTTP(w, r)
	}
	return http.HandlerFunc(fn)
}
// logAccess is middleware that emits one structured "access" log entry per
// request after next has finished. The entry is grouped as:
//   - user:     the client IP as resolved by realip.FromRequest
//   - request:  method, full URL and protocol
//   - response: status code and byte count captured by the metrics writer
func (app *application) logAccess(next http.Handler) http.Handler {
	return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
		// Wrap the writer so status and size can be read back after next runs.
		metrics := response.NewMetricsResponseWriter(w)
		next.ServeHTTP(metrics, r)

		app.logger.Info("access",
			slog.Group("user",
				"ip", realip.FromRequest(r),
			),
			slog.Group("request",
				"method", r.Method,
				"url", r.URL.String(),
				"proto", r.Proto,
			),
			slog.Group("response",
				"status", metrics.StatusCode,
				"size", metrics.BytesCount,
			),
		)
	})
}
// preventCSRF is middleware that wraps next in a nosurf CSRF-protection
// handler. The token cookie has these properties:
//   - site-wide (Path "/") and HTTP-only
//   - Secure, with SameSite=Lax
//   - a max age of 86400 seconds (one day)
func (app *application) preventCSRF(next http.Handler) http.Handler {
	baseCookie := http.Cookie{
		Path:     "/",
		MaxAge:   86400,
		HttpOnly: true,
		Secure:   true,
		SameSite: http.SameSiteLaxMode,
	}

	handler := nosurf.New(next)
	handler.SetBaseCookie(baseCookie)
	return handler
}
// authenticate is middleware that resolves the logged-in user, if any, and
// stores it in the request context for downstream handlers.
//
// It reads an int "userID" from the "session" session and loads that user
// with app.db.GetUser. The user is attached to the request via
// contextSetAuthenticatedUser only when the lookup reports it was found.
// Requests without such an ID pass through unchanged. Failure to load the
// session or the user results in app.serverError and the chain stops.
func (app *application) authenticate(next http.Handler) http.Handler {
	return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
		session, err := app.sessionStore.Get(r, "session")
		if err != nil {
			app.serverError(w, r, err)
			return
		}

		if id, isInt := session.Values["userID"].(int); isInt {
			user, exists, lookupErr := app.db.GetUser(id)
			if lookupErr != nil {
				app.serverError(w, r, lookupErr)
				return
			}
			if exists {
				r = contextSetAuthenticatedUser(r, user)
			}
		}

		next.ServeHTTP(w, r)
	})
}
// requireAuthenticatedUser is middleware that only lets requests carrying an
// authenticated user in their context (see contextGetAuthenticatedUser)
// through to next.
//
// Anonymous requests are redirected to /login with 303 See Other. Before the
// redirect, the originally requested URI (path plus query string) is stored in
// the "session" session under "redirectPathAfterLogin" so the user can be sent
// back there after logging in. Failure to load or save the session results in
// app.serverError.
//
// Responses for authenticated users are marked "Cache-Control: no-store".
func (app *application) requireAuthenticatedUser(next http.Handler) http.Handler {
	return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
		if contextGetAuthenticatedUser(r) == nil {
			session, err := app.sessionStore.Get(r, "session")
			if err != nil {
				app.serverError(w, r, err)
				return
			}
			// RequestURI keeps the query string that r.URL.Path would drop,
			// e.g. /search?q=x would otherwise come back as /search.
			session.Values["redirectPathAfterLogin"] = r.URL.RequestURI()
			if err := session.Save(r, w); err != nil {
				app.serverError(w, r, err)
				return
			}
			http.Redirect(w, r, "/login", http.StatusSeeOther)
			return
		}

		// Set rather than Add so no second, possibly conflicting,
		// Cache-Control header is emitted if one was already set upstream.
		w.Header().Set("Cache-Control", "no-store")
		next.ServeHTTP(w, r)
	})
}
// requireAnonymousUser is middleware that only lets requests with no
// authenticated user in their context through to next. A request that does
// carry an authenticated user is redirected to "/" with 303 See Other.
func (app *application) requireAnonymousUser(next http.Handler) http.Handler {
	return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
		if contextGetAuthenticatedUser(r) == nil {
			next.ServeHTTP(w, r)
			return
		}
		http.Redirect(w, r, "/", http.StatusSeeOther)
	})
}