package auth

import (
	"context"
	"net/http"
	"strings"
)

// contextKey is an unexported type for context keys set by this package, so
// they cannot collide with keys of the same string value from other packages.
type contextKey string

// UserContextKey is the context key under which AuthMiddleware stores the
// claims value returned by ValidateJWT.
const UserContextKey contextKey = contextKey("user")

// bearerPrefix is the Authorization header scheme prefix. It is compared
// case-insensitively because HTTP auth-scheme names are case-insensitive.
const bearerPrefix = "Bearer "

// AuthMiddleware returns middleware that authenticates each request with a
// token validated by ValidateJWT using secret.
//
// The token is read from an "Authorization: Bearer <token>" header, falling
// back to the "access_token" cookie when the header yields no token. On
// success the claims are stored in the request context under UserContextKey
// and the request is passed to next. When the token is missing or invalid,
// requests whose path starts with "/api/" receive a 401 response and all
// other requests are redirected to "/login" with 303 See Other.
func AuthMiddleware(secret []byte) func(http.Handler) http.Handler {
	return func(next http.Handler) http.Handler {
		return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
			tokenStr := tokenFromRequest(r)
			if tokenStr == "" {
				rejectUnauthenticated(w, r, "Missing token")
				return
			}

			claims, err := ValidateJWT(tokenStr, secret)
			if err != nil {
				rejectUnauthenticated(w, r, "Invalid token")
				return
			}

			// NOTE(review): RequireRole asserts this value as *Claims; confirm
			// ValidateJWT returns a *Claims, otherwise every role check fails.
			ctx := context.WithValue(r.Context(), UserContextKey, claims)
			next.ServeHTTP(w, r.WithContext(ctx))
		})
	}
}

// RequireRole returns middleware that allows the request through only when
// the claims stored by AuthMiddleware have a Role equal to role.
//
// If no *Claims value is present in the context (for example, the handler was
// mounted without AuthMiddleware), the request is treated as unauthenticated
// instead of panicking. A present but mismatched role yields 403 Forbidden.
func RequireRole(role string) func(http.Handler) http.Handler {
	return func(next http.Handler) http.Handler {
		return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
			// Comma-ok assertion: a missing or differently-typed value must not
			// crash the handler.
			claims, ok := r.Context().Value(UserContextKey).(*Claims)
			if !ok || claims == nil {
				rejectUnauthenticated(w, r, "Unauthorized")
				return
			}
			if claims.Role != role {
				http.Error(w, "Forbidden", http.StatusForbidden)
				return
			}
			next.ServeHTTP(w, r)
		})
	}
}

// tokenFromRequest extracts the raw token from the Authorization header's
// Bearer scheme, or from the "access_token" cookie if the header yields none.
// It returns "" when neither source provides a token.
func tokenFromRequest(r *http.Request) string {
	authHeader := r.Header.Get("Authorization")
	if len(authHeader) >= len(bearerPrefix) &&
		strings.EqualFold(authHeader[:len(bearerPrefix)], bearerPrefix) {
		if tok := strings.TrimSpace(authHeader[len(bearerPrefix):]); tok != "" {
			return tok
		}
	}
	if cookie, err := r.Cookie("access_token"); err == nil {
		return cookie.Value
	}
	return ""
}

// rejectUnauthenticated responds to an unauthenticated request: API paths
// ("/api/" prefix) get a 401 with msg as the body, while other paths are
// redirected to the login page.
func rejectUnauthenticated(w http.ResponseWriter, r *http.Request, msg string) {
	if strings.HasPrefix(r.URL.Path, "/api/") {
		http.Error(w, msg, http.StatusUnauthorized)
		return
	}
	http.Redirect(w, r, "/login", http.StatusSeeOther)
}