Files
MiauInv/auth/middleware.go

80 lines
1.9 KiB
Go

package auth
import (
"context"
"net/http"
"strings"
)
// contextKey is an unexported key type for values this package stores in a
// request context; being a distinct type, it cannot collide with context keys
// declared in other packages.
type contextKey string

// UserContextKey is the context key under which AuthMiddleware stores the
// claims returned by ValidateJWT for the authenticated request.
const UserContextKey = contextKey("user")
// AuthMiddleware returns middleware that requires a valid JWT on every request
// except the public paths "/", "/login" and "/register".
//
// The token is taken from an "Authorization: Bearer <token>" header (the
// scheme is matched case-insensitively) or, if that yields nothing, from the
// "access_token" cookie, and is checked with ValidateJWT using secret. On
// success the returned claims are stored in the request context under
// UserContextKey and the request is passed to next.
//
// Failures are handled per path:
//   - under "/api/": 401 Unauthorized with a WWW-Authenticate challenge;
//   - elsewhere: 303 redirect to "/login"; if a token was supplied but failed
//     validation, the "access_token" cookie is cleared first.
func AuthMiddleware(secret []byte) func(http.Handler) http.Handler {
	return func(next http.Handler) http.Handler {
		return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
			if isPublicAuthPath(r.URL.Path) {
				next.ServeHTTP(w, r)
				return
			}
			isAPI := strings.HasPrefix(r.URL.Path, "/api/")

			tokenStr := authTokenFromRequest(r)
			if tokenStr == "" {
				if isAPI {
					// RFC 7235 §3.1: a 401 must carry a WWW-Authenticate
					// challenge. RFC 6750 §3: omit the error code when no
					// credentials were sent at all.
					w.Header().Set("WWW-Authenticate", "Bearer")
					http.Error(w, "Missing token", http.StatusUnauthorized)
					return
				}
				http.Redirect(w, r, "/login", http.StatusSeeOther)
				return
			}

			claims, err := ValidateJWT(tokenStr, secret)
			if err != nil {
				if isAPI {
					w.Header().Set("WWW-Authenticate", `Bearer error="invalid_token"`)
					http.Error(w, "Invalid token", http.StatusUnauthorized)
					return
				}
				// Expire the cookie so the browser stops resending a token
				// that has already been rejected.
				http.SetCookie(w, &http.Cookie{
					Name:     "access_token",
					Value:    "",
					Path:     "/",
					MaxAge:   -1,
					HttpOnly: false,
				})
				http.Redirect(w, r, "/login", http.StatusSeeOther)
				return
			}

			ctx := context.WithValue(r.Context(), UserContextKey, claims)
			next.ServeHTTP(w, r.WithContext(ctx))
		})
	}
}

// isPublicAuthPath reports whether path is served without authentication.
func isPublicAuthPath(path string) bool {
	switch path {
	case "/", "/login", "/register":
		return true
	}
	return false
}

// authTokenFromRequest returns the token from an "Authorization: Bearer ..."
// header, falling back to the "access_token" cookie; it returns "" when
// neither supplies a token.
func authTokenFromRequest(r *http.Request) string {
	const scheme = "Bearer "
	h := r.Header.Get("Authorization")
	// Auth schemes are case-insensitive (RFC 7235 §2.1), so accept "bearer " too.
	if len(h) >= len(scheme) && strings.EqualFold(h[:len(scheme)], scheme) {
		if tok := strings.TrimSpace(h[len(scheme):]); tok != "" {
			return tok
		}
	}
	if c, err := r.Cookie("access_token"); err == nil {
		return c.Value
	}
	return ""
}
// RequireRole returns middleware that passes a request to next only when the
// *Claims stored under UserContextKey (normally placed there by
// AuthMiddleware) has Role equal to role; otherwise it responds 403 Forbidden.
//
// If the context holds no *Claims, or a nil one, it responds 401 Unauthorized
// rather than panicking. That happens when the handler is not wrapped by
// AuthMiddleware, or for the public paths AuthMiddleware passes through
// without setting claims.
func RequireRole(role string) func(http.Handler) http.Handler {
	return func(next http.Handler) http.Handler {
		return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
			// Comma-ok form: a bare assertion would panic when the value is absent.
			claims, ok := r.Context().Value(UserContextKey).(*Claims)
			if !ok || claims == nil {
				http.Error(w, "Unauthorized", http.StatusUnauthorized)
				return
			}
			if claims.Role != role {
				http.Error(w, "Forbidden", http.StatusForbidden)
				return
			}
			next.ServeHTTP(w, r)
		})
	}
}