Files
MiauInv/handlers/activity.go
2026-06-10 14:17:33 +02:00

224 lines
5.8 KiB
Go

package handlers
import (
"MiauInv/auth"
"MiauInv/models"
"MiauInv/storage"
"log"
"net"
"net/http"
"strconv"
"strings"
)
// activityResponseWriter wraps an http.ResponseWriter and remembers the final
// status code sent to the client so ActivityMiddleware can log it.
type activityResponseWriter struct {
	http.ResponseWriter
	// statusCode is the first final (non-1xx, or 101) status sent; 0 means
	// nothing has been committed yet.
	statusCode int
}

// WriteHeader records the status code and forwards the call.
//
// Only the first final status is recorded. net/http ignores later
// (superfluous) WriteHeader calls, so recording them would log a status the
// client never received. Informational 1xx codes other than 101 do not commit
// the response, so they are forwarded but not recorded.
func (w *activityResponseWriter) WriteHeader(statusCode int) {
	informational := statusCode >= 100 && statusCode <= 199 && statusCode != http.StatusSwitchingProtocols
	if w.statusCode == 0 && !informational {
		w.statusCode = statusCode
	}
	w.ResponseWriter.WriteHeader(statusCode)
}

// Write forwards data to the wrapped writer. The first Write without a prior
// final WriteHeader implicitly commits 200 OK, which is recorded here.
func (w *activityResponseWriter) Write(data []byte) (int, error) {
	if w.statusCode == 0 {
		w.statusCode = http.StatusOK
	}
	return w.ResponseWriter.Write(data)
}

// Unwrap returns the wrapped writer so http.ResponseController can reach
// optional capabilities (Flush, deadlines, ...) hidden behind this wrapper.
func (w *activityResponseWriter) Unwrap() http.ResponseWriter {
	return w.ResponseWriter
}
func ActivityMiddleware(entityType string, includeGET bool) func(http.Handler) http.Handler {
return func(next http.Handler) http.Handler {
return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
recorder := &activityResponseWriter{ResponseWriter: w}
next.ServeHTTP(recorder, r)
statusCode := recorder.statusCode
if statusCode == 0 {
statusCode = http.StatusOK
}
if !shouldRecordActivity(r, includeGET) {
return
}
claims, ok := r.Context().Value(auth.UserContextKey).(*auth.Claims)
if !ok || claims == nil {
return
}
username := ""
if user, err := storage.GetUserById(claims.UserID); err == nil {
username = user.Username
}
RecordActivity(r, claims.UserID, username, activityAction(r.Method, r.URL.Path), entityType, activityEntityID(r), activityDetails(statusCode), statusCode)
})
}
}
// ActivityLog serves a page of activity-log entries as JSON.
//
// Only GET is accepted (405 otherwise), and the caller must have *auth.Claims in
// the request context (401 otherwise). Query parameters:
//   - limit:  page size, default 50, clamped to [1, 100]
//   - offset: default 0, clamped to [0, 100000]
//   - all:    "true" (case-insensitive) requests every user's entries; it is
//     honoured only for models.RoleAdmin and ignored for everyone else.
//
// The response echoes the effective limit, offset and all flag alongside the
// entries. A storage failure is logged and answered with 500.
func ActivityLog(w http.ResponseWriter, r *http.Request) {
	if r.Method != http.MethodGet {
		http.Error(w, "Method not allowed", http.StatusMethodNotAllowed)
		return
	}

	claims, ok := r.Context().Value(auth.UserContextKey).(*auth.Claims)
	if !ok || claims == nil {
		http.Error(w, "Unauthorized", http.StatusUnauthorized)
		return
	}

	query := r.URL.Query()
	limit := parseBoundedInt(query.Get("limit"), 50, 1, 100)
	offset := parseBoundedInt(query.Get("offset"), 0, 0, 100000)
	// Non-admins asking for all=true are quietly ignored rather than rejected.
	isAdmin := claims.Role == models.RoleAdmin
	includeAll := isAdmin && strings.EqualFold(query.Get("all"), "true")

	entries, err := storage.ListActivityLogs(claims.UserID, includeAll, limit, offset)
	if err != nil {
		log.Println("GET [api/activity] " + r.RemoteAddr + ": " + err.Error())
		http.Error(w, "Could not load activity log", http.StatusInternalServerError)
		return
	}

	response := map[string]interface{}{
		"activity": entries,
		"limit":    limit,
		"offset":   offset,
		"all":      includeAll,
	}
	writeJSON(w, http.StatusOK, response)
}
// RecordActivity persists one activity-log entry describing request r.
//
// A statusCode of 0 is treated as 200 OK, and Success is true for 2xx and 3xx
// codes. Free-text fields are trimmed and truncated to fixed rune lengths
// (presumably matching the storage column sizes — confirm against the schema).
// The client address is taken from clientIP. A storage failure is only logged,
// never returned, so request handling is unaffected by logging problems.
func RecordActivity(r *http.Request, userID, username, action, entityType, entityID, details string, statusCode int) {
	code := statusCode
	if code == 0 {
		code = http.StatusOK
	}
	succeeded := code >= http.StatusOK && code < http.StatusBadRequest

	entry := models.ActivityLogEntry{
		UserID:     userID,
		Username:   truncateForActivity(username, 120),
		Action:     truncateForActivity(action, 120),
		EntityType: truncateForActivity(entityType, 80),
		EntityID:   truncateForActivity(entityID, 120),
		Details:    truncateForActivity(details, 500),
		// NOTE(review): r.Method is client-supplied and stored untruncated.
		Method:     r.Method,
		Path:       truncateForActivity(r.URL.Path, 255),
		StatusCode: code,
		Success:    succeeded,
		IPAddress:  truncateForActivity(clientIP(r), 80),
		UserAgent:  truncateForActivity(r.UserAgent(), 500),
	}

	err := storage.AddActivityLog(entry)
	if err == nil {
		return
	}
	log.Println("ACTIVITY " + r.RemoteAddr + ": " + err.Error())
}
// shouldRecordActivity reports whether a request with r's method belongs in the
// activity log. OPTIONS and HEAD are never logged. GET is logged only when
// includeGET is true. Every other method is always logged.
func shouldRecordActivity(r *http.Request, includeGET bool) bool {
	switch r.Method {
	case http.MethodOptions, http.MethodHead:
		return false
	case http.MethodGet:
		return includeGET
	default:
		return true
	}
}
// activityAction maps an HTTP method and request path to a dotted action name.
//
// Known account/auth/security endpoints (matched after dropping one trailing
// slash) get a fixed name. "/api/passkeys" is "security.passkey.delete" for
// DELETE and "security.passkey.read" for any other method. Every other path
// falls back to an "inventory.*" name chosen by method, and unrecognised
// methods are returned lower-cased.
func activityAction(method, path string) string {
	switch strings.TrimSuffix(path, "/") {
	case "/api/logout":
		return "auth.logout"
	case "/api/account/username":
		return "account.username.update"
	case "/api/account/password":
		return "account.password.update"
	case "/api/2fa/setup":
		return "security.2fa.setup"
	case "/api/2fa/enable":
		return "security.2fa.enable"
	case "/api/2fa/disable":
		return "security.2fa.disable"
	case "/api/2fa/recovery-codes/regenerate":
		return "security.2fa.recovery_codes.regenerate"
	case "/api/passkeys/register/options":
		return "security.passkey.registration.start"
	case "/api/passkeys/register/finish":
		return "security.passkey.registration.finish"
	case "/api/passkeys/disable":
		return "security.passkey.disable"
	case "/api/passkeys":
		if method != http.MethodDelete {
			return "security.passkey.read"
		}
		return "security.passkey.delete"
	}

	switch method {
	case http.MethodGet:
		return "inventory.read"
	case http.MethodPost:
		return "inventory.create"
	case http.MethodPut:
		return "inventory.update"
	case http.MethodDelete:
		return "inventory.delete"
	}
	return strings.ToLower(method)
}
// activityEntityID returns the whitespace-trimmed "id" query parameter of r,
// or "" when it is absent or blank.
func activityEntityID(r *http.Request) string {
	return strings.TrimSpace(r.URL.Query().Get("id"))
}
// activityDetails returns a short human-readable description of a response
// status for the activity log.
//
// 2xx and 3xx codes yield a fixed success message. Other codes yield their
// standard reason phrase (http.StatusText). Non-standard codes, for which
// StatusText returns "", fall back to "HTTP status <code>" so the entry is
// never left blank.
func activityDetails(statusCode int) string {
	if statusCode >= 200 && statusCode < 400 {
		return "Request completed successfully."
	}
	if text := http.StatusText(statusCode); text != "" {
		return text
	}
	return "HTTP status " + strconv.Itoa(statusCode)
}
// clientIP returns the best-effort client address recorded in the activity log.
//
// X-Forwarded-For and X-Real-IP are set by whoever sends the request, so they
// are honoured only when the direct peer (r.RemoteAddr) is a loopback or
// private address, i.e. when the request most likely came through a local
// reverse proxy. A header value is used only if it parses as an IP address, so
// arbitrary text cannot be written into the audit record. The first
// X-Forwarded-For entry is tried before X-Real-IP. Otherwise the host part of
// r.RemoteAddr is returned, or r.RemoteAddr verbatim if it has no port.
//
// NOTE(review): deployments whose proxy connects over a Unix socket or from a
// public address will see the proxy's address here; extend the trust check if so.
func clientIP(r *http.Request) string {
	peer := r.RemoteAddr
	if host, _, err := net.SplitHostPort(r.RemoteAddr); err == nil {
		peer = host
	}

	peerIP := net.ParseIP(peer)
	if peerIP == nil || !(peerIP.IsLoopback() || peerIP.IsPrivate()) {
		// Untrusted peer: forwarding headers could be spoofed.
		return peer
	}

	candidates := []string{
		strings.SplitN(r.Header.Get("X-Forwarded-For"), ",", 2)[0],
		r.Header.Get("X-Real-IP"),
	}
	for _, candidate := range candidates {
		if ip := net.ParseIP(strings.TrimSpace(candidate)); ip != nil {
			return ip.String()
		}
	}
	return peer
}
// parseBoundedInt parses raw as a base-10 integer and clamps it to [lo, hi].
//
// Surrounding whitespace is ignored. An empty or unparseable value (including
// one that overflows int) yields fallback unchanged, without clamping.
func parseBoundedInt(raw string, fallback, lo, hi int) int {
	trimmed := strings.TrimSpace(raw)
	if trimmed == "" {
		return fallback
	}
	// Parse the trimmed text: the emptiness check above already treats
	// whitespace as insignificant, so " 7" must mean 7, not fallback.
	value, err := strconv.Atoi(trimmed)
	if err != nil {
		return fallback
	}
	if value < lo {
		return lo
	}
	if value > hi {
		return hi
	}
	return value
}
// truncateForActivity trims surrounding whitespace from value and limits the
// result to at most max runes (not bytes). A non-positive max disables the
// limit. Untruncated values are returned as-is. Only the truncating path
// converts through []rune.
func truncateForActivity(value string, max int) string {
	trimmed := strings.TrimSpace(value)
	if max <= 0 {
		return trimmed
	}
	// Ranging over a string visits runes, so this counts exactly what
	// len([]rune(trimmed)) would, without allocating.
	runeCount := 0
	for range trimmed {
		runeCount++
	}
	if runeCount <= max {
		return trimmed
	}
	return string([]rune(trimmed)[:max])
}