package handlers

import (
	"log"
	"net"
	"net/http"
	"strconv"
	"strings"

	"MiauInv/auth"
	"MiauInv/models"
	"MiauInv/storage"
)

// activityResponseWriter wraps an http.ResponseWriter so the activity
// middleware can see which status code the wrapped handler sent.
type activityResponseWriter struct {
	http.ResponseWriter
	// statusCode is the final status sent to the client, or 0 if the
	// handler has not written a header or body yet.
	statusCode int
}

// WriteHeader records the first final status code and forwards the call.
//
// net/http allows any number of informational 1xx headers followed by at
// most one final 2xx-5xx header; later final calls are ignored. Recording
// only the first final code keeps the logged status equal to the one the
// client actually received (e.g. a handler that calls Write and then
// http.Error still sent 200). 1xx codes other than 101 Switching Protocols
// are not final, so they are forwarded without being recorded.
func (w *activityResponseWriter) WriteHeader(statusCode int) {
	informational := statusCode >= 100 && statusCode < 200 &&
		statusCode != http.StatusSwitchingProtocols
	if w.statusCode == 0 && !informational {
		w.statusCode = statusCode
	}
	w.ResponseWriter.WriteHeader(statusCode)
}

// Write forwards data to the underlying writer. A Write before any final
// WriteHeader implies a 200 OK response, which is recorded here.
func (w *activityResponseWriter) Write(data []byte) (int, error) {
	if w.statusCode == 0 {
		w.statusCode = http.StatusOK
	}
	return w.ResponseWriter.Write(data)
}

// Unwrap returns the wrapped writer so http.ResponseController can reach
// optional capabilities (flushing, hijacking, deadlines) that this wrapper
// does not implement itself.
func (w *activityResponseWriter) Unwrap() http.ResponseWriter {
	return w.ResponseWriter
}

// ActivityMiddleware returns middleware that records one activity-log
// entry, tagged with entityType, for each request handled by next.
//
// The entry is written after next returns, using the status code the
// handler sent (200 if it sent nothing). OPTIONS and HEAD requests are
// never recorded; GET requests are recorded only when includeGET is true.
// Requests whose context carries no *auth.Claims are skipped.
func ActivityMiddleware(entityType string, includeGET bool) func(http.Handler) http.Handler {
	return func(next http.Handler) http.Handler {
		return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
			recorder := &activityResponseWriter{ResponseWriter: w}
			next.ServeHTTP(recorder, r)

			if !shouldRecordActivity(r, includeGET) {
				return
			}
			// NOTE(review): claims are read from r as it reached this
			// middleware, so presumably the auth middleware must run
			// before (outside) this one — confirm against the router setup.
			claims, ok := r.Context().Value(auth.UserContextKey).(*auth.Claims)
			if !ok || claims == nil {
				return
			}

			statusCode := recorder.statusCode
			if statusCode == 0 {
				// The handler wrote nothing; net/http replies 200 OK.
				statusCode = http.StatusOK
			}

			// The username is best-effort context: a failed lookup still
			// records the entry, just with an empty name.
			username := ""
			if user, err := storage.GetUserById(claims.UserID); err == nil {
				username = user.Username
			}

			RecordActivity(r, claims.UserID, username,
				activityAction(r.Method, r.URL.Path), entityType,
				activityEntityID(r), activityDetails(statusCode), statusCode)
		})
	}
}

// ActivityLog serves the activity log as JSON for GET requests.
//
// Query parameters:
//   - limit:  page size, default 50, clamped to [1, 100]
//   - offset: number of entries to skip, default 0, clamped to [0, 100000]
//   - all:    "true" (case-insensitive) passes includeAll=true to
//     storage.ListActivityLogs; honoured only for models.RoleAdmin
//
// It responds 405 for non-GET methods, 401 when the request context has no
// *auth.Claims, and 500 when storage.ListActivityLogs fails.
func ActivityLog(w http.ResponseWriter, r *http.Request) {
	if r.Method != http.MethodGet {
		http.Error(w, "Method not allowed", http.StatusMethodNotAllowed)
		return
	}
	claims, ok := r.Context().Value(auth.UserContextKey).(*auth.Claims)
	if !ok || claims == nil {
		http.Error(w, "Unauthorized", http.StatusUnauthorized)
		return
	}

	// URL.Query re-parses RawQuery on every call, so parse it once.
	query := r.URL.Query()
	limit := parseBoundedInt(query.Get("limit"), 50, 1, 100)
	offset := parseBoundedInt(query.Get("offset"), 0, 0, 100000)
	includeAll := claims.Role == models.RoleAdmin && strings.EqualFold(query.Get("all"), "true")

	entries, err := storage.ListActivityLogs(claims.UserID, includeAll, limit, offset)
	if err != nil {
		log.Println("GET [api/activity] " + r.RemoteAddr + ": " + err.Error())
		http.Error(w, "Could not load activity log", http.StatusInternalServerError)
		return
	}

	writeJSON(w, http.StatusOK, map[string]interface{}{
		"activity": entries,
		"limit":    limit,
		"offset":   offset,
		"all":      includeAll,
	})
}

// RecordActivity stores one activity-log entry describing r via
// storage.AddActivityLog. A statusCode of 0 is treated as 200 OK, and
// 2xx/3xx codes mark the entry as successful. Free-text fields are trimmed
// and truncated to fixed rune lengths before storage. A storage failure is
// logged and otherwise ignored.
func RecordActivity(r *http.Request, userID, username, action, entityType, entityID, details string, statusCode int) {
	if statusCode == 0 {
		statusCode = http.StatusOK
	}

	entry := models.ActivityLogEntry{
		UserID:     userID,
		Username:   truncateForActivity(username, 120),
		Action:     truncateForActivity(action, 120),
		EntityType: truncateForActivity(entityType, 80),
		EntityID:   truncateForActivity(entityID, 120),
		Details:    truncateForActivity(details, 500),
		Method:     r.Method,
		Path:       truncateForActivity(r.URL.Path, 255),
		StatusCode: statusCode,
		Success:    activityStatusSucceeded(statusCode),
		IPAddress:  truncateForActivity(clientIP(r), 80),
		UserAgent:  truncateForActivity(r.UserAgent(), 500),
	}

	if err := storage.AddActivityLog(entry); err != nil {
		log.Println("ACTIVITY " + r.RemoteAddr + ": " + err.Error())
	}
}

// shouldRecordActivity reports whether a request with r's method should be
// logged: OPTIONS and HEAD never are, GET only when includeGET is true, and
// every other method always.
func shouldRecordActivity(r *http.Request, includeGET bool) bool {
	if r.Method == http.MethodOptions || r.Method == http.MethodHead {
		return false
	}
	if includeGET {
		return true
	}
	return r.Method != http.MethodGet
}

// activityAction maps a request's method and path to an action name.
// Known account and security endpoints (compared without a trailing slash)
// get fixed names; any other path falls back to an "inventory.*" action
// chosen by HTTP method, or the lower-cased method for unlisted methods.
func activityAction(method, path string) string {
	switch strings.TrimSuffix(path, "/") {
	case "/api/logout":
		return "auth.logout"
	case "/api/account/username":
		return "account.username.update"
	case "/api/account/password":
		return "account.password.update"
	case "/api/2fa/setup":
		return "security.2fa.setup"
	case "/api/2fa/enable":
		return "security.2fa.enable"
	case "/api/2fa/disable":
		return "security.2fa.disable"
	case "/api/2fa/recovery-codes/regenerate":
		return "security.2fa.recovery_codes.regenerate"
	case "/api/passkeys/register/options":
		return "security.passkey.registration.start"
	case "/api/passkeys/register/finish":
		return "security.passkey.registration.finish"
	case "/api/passkeys/disable":
		return "security.passkey.disable"
	case "/api/passkeys":
		if method == http.MethodDelete {
			return "security.passkey.delete"
		}
		return "security.passkey.read"
	}

	switch method {
	case http.MethodPost:
		return "inventory.create"
	case http.MethodPut:
		return "inventory.update"
	case http.MethodDelete:
		return "inventory.delete"
	case http.MethodGet:
		return "inventory.read"
	default:
		return strings.ToLower(method)
	}
}

// activityEntityID returns the whitespace-trimmed "id" query parameter, or
// "" when it is absent or blank.
func activityEntityID(r *http.Request) string {
	return strings.TrimSpace(r.URL.Query().Get("id"))
}

// activityDetails returns a short human-readable outcome for statusCode.
func activityDetails(statusCode int) string {
	if activityStatusSucceeded(statusCode) {
		return "Request completed successfully."
	}
	return http.StatusText(statusCode)
}

// activityStatusSucceeded reports whether statusCode counts as a successful
// request in the activity log (2xx or 3xx).
func activityStatusSucceeded(statusCode int) bool {
	return statusCode >= 200 && statusCode < 400
}

// clientIP returns the client address recorded in the activity log: the
// first X-Forwarded-For entry if present, else X-Real-IP, else the host part
// of r.RemoteAddr (or RemoteAddr verbatim if it has no port).
//
// NOTE(review): both forwarding headers are taken from the request without
// checking that it arrived through a trusted proxy, so any client can choose
// the logged IP. If the audit IP must be reliable, honour these headers only
// for requests whose RemoteAddr is a known proxy.
func clientIP(r *http.Request) string {
	if forwardedFor := strings.TrimSpace(r.Header.Get("X-Forwarded-For")); forwardedFor != "" {
		parts := strings.Split(forwardedFor, ",")
		return strings.TrimSpace(parts[0])
	}
	if realIP := strings.TrimSpace(r.Header.Get("X-Real-IP")); realIP != "" {
		return realIP
	}
	host, _, err := net.SplitHostPort(r.RemoteAddr)
	if err == nil {
		return host
	}
	return r.RemoteAddr
}

// parseBoundedInt parses raw as a decimal integer and clamps it to [lo, hi].
// Surrounding whitespace is ignored. Blank or unparsable input (including
// values that overflow int) yields fallback, which is returned unclamped.
func parseBoundedInt(raw string, fallback, lo, hi int) int {
	// Trim before parsing too, so " 10" parses instead of falling back.
	raw = strings.TrimSpace(raw)
	if raw == "" {
		return fallback
	}
	value, err := strconv.Atoi(raw)
	if err != nil {
		return fallback
	}
	if value < lo {
		return lo
	}
	if value > hi {
		return hi
	}
	return value
}

// truncateForActivity trims surrounding whitespace from value and cuts it to
// at most maxRunes runes (Unicode code points). A maxRunes of zero or less
// disables truncation.
func truncateForActivity(value string, maxRunes int) string {
	value = strings.TrimSpace(value)
	if maxRunes <= 0 {
		return value
	}
	// Walk rune boundaries rather than converting to []rune, so strings that
	// already fit (the common case) are returned without allocating.
	count := 0
	for i := range value {
		if count == maxRunes {
			return value[:i]
		}
		count++
	}
	return value
}