Files
MiauInv/storage/storage.go
2026-06-10 14:17:33 +02:00

461 lines
12 KiB
Go

package storage
import (
"MiauInv/models"
utils "MiauInv/util"
"database/sql"
"errors"
"log"
"strings"
"time"
_ "github.com/glebarez/go-sqlite"
)
// ErrNotFound is the sentinel returned by the update helpers in this package
// when a statement matched no row. It is the same value as sql.ErrNoRows, so
// a miss from a QueryRow(...).Scan lookup compares equal to it as well.
var ErrNotFound = sql.ErrNoRows
// DB is the package-wide database handle. It is nil until InitDB assigns it.
var DB *sql.DB
// InitDB opens the SQLite database at filepath, stores the handle in the
// package-level DB, creates every table that does not exist yet, and then
// runs the column and index migrations.
//
// Any failure is returned to the caller; the function never terminates the
// process on its own.
//
// NOTE(review): PRAGMA foreign_keys is a per-connection setting in SQLite,
// while database/sql may open additional pooled connections later. Confirm
// that foreign-key enforcement is also configured for those connections
// (for example through the driver's DSN options).
func InitDB(filepath string) error {
	var err error
	DB, err = sql.Open("sqlite", filepath)
	if err != nil {
		return err
	}
	schema := `
PRAGMA foreign_keys = ON;
CREATE TABLE IF NOT EXISTS users (
id TEXT PRIMARY KEY,
username TEXT NOT NULL UNIQUE,
password TEXT NOT NULL,
role TEXT NOT NULL,
two_factor_enabled INTEGER NOT NULL DEFAULT 0,
two_factor_secret TEXT NOT NULL DEFAULT ''
);
CREATE TABLE IF NOT EXISTS refresh_tokens (
id TEXT PRIMARY KEY,
user_id TEXT NOT NULL,
token_hash TEXT NOT NULL,
expires_at INTEGER NOT NULL,
created_at INTEGER NOT NULL,
revoked INTEGER NOT NULL DEFAULT 0,
device_info TEXT,
FOREIGN KEY(user_id) REFERENCES users(id)
);
CREATE TABLE IF NOT EXISTS two_factor_recovery_codes (
id TEXT PRIMARY KEY,
user_id TEXT NOT NULL,
code_hash TEXT NOT NULL,
created_at INTEGER NOT NULL,
used_at INTEGER DEFAULT NULL,
FOREIGN KEY(user_id) REFERENCES users(id) ON DELETE CASCADE,
UNIQUE(user_id, code_hash)
);
CREATE TABLE IF NOT EXISTS passkey_credentials (
id TEXT PRIMARY KEY,
user_id TEXT NOT NULL,
credential_id TEXT NOT NULL UNIQUE,
name TEXT NOT NULL,
credential_data TEXT NOT NULL,
created_at INTEGER NOT NULL,
last_used_at INTEGER DEFAULT NULL,
FOREIGN KEY(user_id) REFERENCES users(id) ON DELETE CASCADE
);
CREATE TABLE IF NOT EXISTS passkey_challenges (
token TEXT PRIMARY KEY,
user_id TEXT NOT NULL DEFAULT '',
ceremony TEXT NOT NULL,
session_data TEXT NOT NULL,
expires_at INTEGER NOT NULL
);
CREATE TABLE IF NOT EXISTS activity_logs (
id TEXT PRIMARY KEY,
user_id TEXT NOT NULL DEFAULT '',
username TEXT NOT NULL DEFAULT '',
action TEXT NOT NULL,
entity_type TEXT NOT NULL DEFAULT '',
entity_id TEXT NOT NULL DEFAULT '',
details TEXT NOT NULL DEFAULT '',
method TEXT NOT NULL DEFAULT '',
path TEXT NOT NULL DEFAULT '',
status_code INTEGER NOT NULL DEFAULT 0,
success INTEGER NOT NULL DEFAULT 0,
ip_address TEXT NOT NULL DEFAULT '',
user_agent TEXT NOT NULL DEFAULT '',
created_at INTEGER NOT NULL
);
CREATE TABLE IF NOT EXISTS items (
id INTEGER PRIMARY KEY AUTOINCREMENT,
name TEXT NOT NULL,
category TEXT,
description TEXT,
total_quantity INTEGER NOT NULL DEFAULT 0
);
CREATE TABLE IF NOT EXISTS locations (
id INTEGER PRIMARY KEY AUTOINCREMENT,
name TEXT NOT NULL UNIQUE
);
CREATE TABLE IF NOT EXISTS projects (
id INTEGER PRIMARY KEY AUTOINCREMENT,
name TEXT NOT NULL UNIQUE,
description TEXT
);
CREATE TABLE IF NOT EXISTS stock (
id INTEGER PRIMARY KEY AUTOINCREMENT,
item_id INTEGER NOT NULL,
location_id INTEGER NOT NULL,
quantity INTEGER NOT NULL,
FOREIGN KEY(item_id) REFERENCES items(id),
FOREIGN KEY(location_id) REFERENCES locations(id)
);
CREATE TABLE IF NOT EXISTS project_items (
id INTEGER PRIMARY KEY AUTOINCREMENT,
item_id INTEGER NOT NULL,
project_id INTEGER NOT NULL,
quantity INTEGER NOT NULL,
FOREIGN KEY(item_id) REFERENCES items(id),
FOREIGN KEY(project_id) REFERENCES projects(id)
);
`
	if _, err = DB.Exec(schema); err != nil {
		// Surface the failure through the error return like every other
		// path of this function, rather than exiting the process here and
		// taking the decision away from the caller.
		log.Printf("storage: applying schema failed: %v", err)
		return err
	}
	// Databases created before the two-factor columns existed need them
	// added after the fact; CREATE TABLE IF NOT EXISTS does not alter an
	// existing table.
	if err := ensureUserTwoFactorColumns(); err != nil {
		return err
	}
	if err := ensureActivityLogIndexes(); err != nil {
		return err
	}
	return nil
}
// ensureUserTwoFactorColumns adds the two-factor columns to the users table.
// An error whose lower-cased message contains "duplicate column" is treated
// as "already migrated" and skipped; any other error aborts and is returned.
func ensureUserTwoFactorColumns() error {
	for _, stmt := range []string{
		"ALTER TABLE users ADD COLUMN two_factor_enabled INTEGER NOT NULL DEFAULT 0",
		"ALTER TABLE users ADD COLUMN two_factor_secret TEXT NOT NULL DEFAULT ''",
	} {
		_, execErr := DB.Exec(stmt)
		if execErr == nil {
			continue
		}
		message := strings.ToLower(execErr.Error())
		if strings.Contains(message, "duplicate column") {
			// Column is already there; nothing to do for this statement.
			continue
		}
		return execErr
	}
	return nil
}
// ensureActivityLogIndexes creates the activity_logs indexes if they are
// missing. It stops at, and returns, the first statement that fails.
func ensureActivityLogIndexes() error {
	statements := [...]string{
		"CREATE INDEX IF NOT EXISTS idx_activity_logs_user_created ON activity_logs(user_id, created_at DESC)",
		"CREATE INDEX IF NOT EXISTS idx_activity_logs_created ON activity_logs(created_at DESC)",
		"CREATE INDEX IF NOT EXISTS idx_activity_logs_action ON activity_logs(action)",
	}
	for i := 0; i < len(statements); i++ {
		_, execErr := DB.Exec(statements[i])
		if execErr != nil {
			return execErr
		}
	}
	return nil
}
// Activity logs
// AddActivityLog inserts one row into activity_logs.
//
// A missing ID is filled with a freshly generated UUID and a zero CreatedAt
// with the current Unix time. entry is passed by value, so those defaults
// are not visible to the caller. Returns an error when the database has not
// been initialized or when the insert fails.
func AddActivityLog(entry models.ActivityLogEntry) error {
	if DB == nil {
		return errors.New("db not initialized")
	}

	const insertStmt = `
INSERT INTO activity_logs(
id, user_id, username, action, entity_type, entity_id, details,
method, path, status_code, success, ip_address, user_agent, created_at
) VALUES (?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?)
`

	if entry.ID == "" {
		entry.ID = utils.GenerateUUID()
	}
	if entry.CreatedAt == 0 {
		entry.CreatedAt = time.Now().Unix()
	}

	// The success column is an INTEGER, so the bool is stored as 0 or 1.
	var successFlag int
	if entry.Success {
		successFlag = 1
	}

	params := []interface{}{
		entry.ID, entry.UserID, entry.Username, entry.Action,
		entry.EntityType, entry.EntityID, entry.Details,
		entry.Method, entry.Path, entry.StatusCode, successFlag,
		entry.IPAddress, entry.UserAgent, entry.CreatedAt,
	}
	_, err := DB.Exec(insertStmt, params...)
	return err
}
// ListActivityLogs returns one page of activity log entries, newest first
// (ordered by created_at, then id, both descending).
//
// When includeAll is false only rows whose user_id equals userID are
// returned; otherwise userID is ignored. limit falls back to 50 when it is
// not positive and is capped at 100; a negative offset is treated as 0.
// The returned slice is non-nil on success, even when no rows match.
func ListActivityLogs(userID string, includeAll bool, limit, offset int) ([]models.ActivityLogEntry, error) {
	if DB == nil {
		return nil, errors.New("db not initialized")
	}

	// Normalize the paging window.
	switch {
	case limit <= 0:
		limit = 50
	case limit > 100:
		limit = 100
	}
	if offset < 0 {
		offset = 0
	}

	var sqlText strings.Builder
	sqlText.WriteString(`
SELECT id, user_id, username, action, entity_type, entity_id, details,
method, path, status_code, success, ip_address, user_agent, created_at
FROM activity_logs
`)
	params := make([]interface{}, 0, 3)
	if !includeAll {
		sqlText.WriteString(" WHERE user_id = ?")
		params = append(params, userID)
	}
	sqlText.WriteString(" ORDER BY created_at DESC, id DESC LIMIT ? OFFSET ?")
	params = append(params, limit, offset)

	rows, err := DB.Query(sqlText.String(), params...)
	if err != nil {
		return nil, err
	}
	defer rows.Close()

	result := []models.ActivityLogEntry{}
	for rows.Next() {
		var (
			rec         models.ActivityLogEntry
			successFlag int
		)
		scanErr := rows.Scan(
			&rec.ID, &rec.UserID, &rec.Username, &rec.Action, &rec.EntityType, &rec.EntityID, &rec.Details,
			&rec.Method, &rec.Path, &rec.StatusCode, &successFlag, &rec.IPAddress, &rec.UserAgent, &rec.CreatedAt,
		)
		if scanErr != nil {
			return nil, scanErr
		}
		// success is stored as an INTEGER; only the value 1 counts as true.
		rec.Success = successFlag == 1
		result = append(result, rec)
	}
	// An iteration error is reported together with the rows read so far.
	return result, rows.Err()
}
// Users
// AddUser inserts a new row into users with the given id, password and role.
// The username is lower-cased before it is stored. The two-factor columns
// are not written by this statement.
func AddUser(user *models.User) error {
	const insertStmt = "INSERT INTO users(id, username, password, role) VALUES (?, ?, ?, ?)"
	_, err := DB.Exec(
		insertStmt,
		user.ID,
		strings.ToLower(user.Username),
		user.Password,
		user.Role,
	)
	return err
}
// GetUserByUsername looks a user up by username. The argument is lower-cased
// before the comparison, mirroring how usernames are written by AddUser and
// UpdateUserUsername. The row is decoded by scanUser.
func GetUserByUsername(username string) (models.User, error) {
	const lookup = `
SELECT id, username, password, role, two_factor_enabled, two_factor_secret
FROM users
WHERE username = ?
`
	normalized := strings.ToLower(username)
	return scanUser(DB.QueryRow(lookup, normalized))
}
// GetUserById looks a user up by primary key and decodes the row with
// scanUser.
func GetUserById(id string) (models.User, error) {
	const lookup = `
SELECT id, username, password, role, two_factor_enabled, two_factor_secret
FROM users
WHERE id = ?
`
	return scanUser(DB.QueryRow(lookup, id))
}
// scanUser decodes a single users row selected in the column order
// id, username, password, role, two_factor_enabled, two_factor_secret.
//
// On any scan error the zero-value user is returned together with the error,
// so a caller can never observe a half-populated record.
func scanUser(row *sql.Row) (models.User, error) {
	var (
		user             models.User
		twoFactorEnabled int
	)
	if err := row.Scan(&user.ID, &user.Username, &user.Password, &user.Role, &twoFactorEnabled, &user.TwoFactorSecret); err != nil {
		return models.User{}, err
	}
	// two_factor_enabled is stored as an INTEGER; only 1 means enabled.
	user.TwoFactorEnabled = twoFactorEnabled == 1
	return user, nil
}
// UpdateUserUsername stores the lower-cased username for the user with the
// given id. It returns ErrNotFound when the statement affected no row.
func UpdateUserUsername(userID, username string) error {
	result, err := DB.Exec("UPDATE users SET username = ? WHERE id = ?", strings.ToLower(username), userID)
	if err != nil {
		return err
	}
	affected, err := result.RowsAffected()
	switch {
	case err != nil:
		return err
	case affected == 0:
		return ErrNotFound
	default:
		return nil
	}
}
// UpdateUserPassword overwrites the stored password value of the user with
// the given id, writing passwordHash as-is. It returns ErrNotFound when the
// statement affected no row.
func UpdateUserPassword(userID, passwordHash string) error {
	const stmt = "UPDATE users SET password = ? WHERE id = ?"
	outcome, execErr := DB.Exec(stmt, passwordHash, userID)
	if execErr != nil {
		return execErr
	}
	if changed, countErr := outcome.RowsAffected(); countErr != nil {
		return countErr
	} else if changed == 0 {
		return ErrNotFound
	}
	return nil
}
// SetUserTwoFactorSecret stores secret in the two_factor_secret column of
// the user with the given id. The two_factor_enabled flag is not touched.
//
// It returns ErrNotFound when no user row matched, in line with
// UpdateUserUsername and UpdateUserPassword, so that writing a secret for an
// unknown user is reported instead of silently doing nothing.
func SetUserTwoFactorSecret(userID, secret string) error {
	res, err := DB.Exec("UPDATE users SET two_factor_secret = ? WHERE id = ?", secret, userID)
	if err != nil {
		return err
	}
	n, err := res.RowsAffected()
	if err != nil {
		return err
	}
	if n == 0 {
		return ErrNotFound
	}
	return nil
}
// EnableUserTwoFactorWithSecretAndRecoveryCodes turns two-factor
// authentication on for a user in a single transaction: it sets
// two_factor_enabled = 1 together with twoFactorSecret, deletes the user's
// existing recovery codes, and inserts one row per entry of
// recoveryCodeHashes.
//
// It returns ErrNotFound, and rolls everything back, when no user row has
// the given id. Without that check the transaction could commit recovery
// codes for a user that does not exist whenever foreign-key enforcement is
// not active on the connection in use. Any other failure also rolls the
// transaction back and is returned.
func EnableUserTwoFactorWithSecretAndRecoveryCodes(userID, twoFactorSecret string, recoveryCodeHashes []string) error {
	tx, err := DB.Begin()
	if err != nil {
		return err
	}
	// Rollback after a successful Commit is a no-op that only reports
	// sql.ErrTxDone, so the deferred call is safe on every path.
	defer tx.Rollback()

	// Update the user first so a missing user is detected before any
	// recovery-code rows are touched.
	res, err := tx.Exec("UPDATE users SET two_factor_enabled = 1, two_factor_secret = ? WHERE id = ?", twoFactorSecret, userID)
	if err != nil {
		return err
	}
	n, err := res.RowsAffected()
	if err != nil {
		return err
	}
	if n == 0 {
		return ErrNotFound
	}

	if _, err := tx.Exec("DELETE FROM two_factor_recovery_codes WHERE user_id = ?", userID); err != nil {
		return err
	}
	// All codes of one batch share the same creation timestamp.
	now := time.Now().Unix()
	for _, codeHash := range recoveryCodeHashes {
		if _, err := tx.Exec(`
INSERT INTO two_factor_recovery_codes(id, user_id, code_hash, created_at)
VALUES (?, ?, ?, ?)
`, utils.GenerateUUID(), userID, codeHash, now); err != nil {
			return err
		}
	}
	return tx.Commit()
}
// DisableUserTwoFactor switches two-factor authentication off for a user.
// Inside one transaction it clears two_factor_enabled and two_factor_secret
// on the user row and then deletes all of that user's recovery codes. If
// either statement fails the transaction is rolled back and the error is
// returned.
func DisableUserTwoFactor(userID string) error {
	txn, beginErr := DB.Begin()
	if beginErr != nil {
		return beginErr
	}
	defer txn.Rollback()

	// Both statements take the user id as their only parameter and must
	// run in this order.
	steps := [...]string{
		"UPDATE users SET two_factor_enabled = 0, two_factor_secret = '' WHERE id = ?",
		"DELETE FROM two_factor_recovery_codes WHERE user_id = ?",
	}
	for _, stmt := range steps {
		_, execErr := txn.Exec(stmt, userID)
		if execErr != nil {
			return execErr
		}
	}
	return txn.Commit()
}
// ReplaceUserRecoveryCodes swaps a user's recovery codes for a new set.
// Within one transaction every existing code of the user is deleted and one
// row per entry of recoveryCodeHashes is inserted, each with a new UUID and
// a shared creation timestamp. Any failure rolls the transaction back.
func ReplaceUserRecoveryCodes(userID string, recoveryCodeHashes []string) error {
	const (
		purgeStmt = "DELETE FROM two_factor_recovery_codes WHERE user_id = ?"
		addStmt   = `
INSERT INTO two_factor_recovery_codes(id, user_id, code_hash, created_at)
VALUES (?, ?, ?, ?)
`
	)

	txn, beginErr := DB.Begin()
	if beginErr != nil {
		return beginErr
	}
	defer txn.Rollback()

	_, purgeErr := txn.Exec(purgeStmt, userID)
	if purgeErr != nil {
		return purgeErr
	}

	createdAt := time.Now().Unix()
	for i := range recoveryCodeHashes {
		_, addErr := txn.Exec(addStmt, utils.GenerateUUID(), userID, recoveryCodeHashes[i], createdAt)
		if addErr != nil {
			return addErr
		}
	}
	return txn.Commit()
}
// UseUserRecoveryCode marks a recovery code as used by stamping used_at with
// the current Unix time. Only a row matching userID and codeHash whose
// used_at is still NULL is updated. The boolean is true exactly when one row
// was updated.
func UseUserRecoveryCode(userID, codeHash string) (bool, error) {
	const consumeStmt = `
UPDATE two_factor_recovery_codes
SET used_at = ?
WHERE user_id = ? AND code_hash = ? AND used_at IS NULL
`
	outcome, execErr := DB.Exec(consumeStmt, time.Now().Unix(), userID, codeHash)
	if execErr != nil {
		return false, execErr
	}
	changed, countErr := outcome.RowsAffected()
	if countErr != nil {
		return false, countErr
	}
	consumed := changed == 1
	return consumed, nil
}
// CountUnusedRecoveryCodes reports how many of the user's recovery codes
// still have a NULL used_at.
func CountUnusedRecoveryCodes(userID string) (int, error) {
	const countStmt = `
SELECT COUNT(*)
FROM two_factor_recovery_codes
WHERE user_id = ? AND used_at IS NULL
`
	var remaining int
	row := DB.QueryRow(countStmt, userID)
	scanErr := row.Scan(&remaining)
	return remaining, scanErr
}
// Refresh Tokens
// AddRefreshToken inserts a refresh token row. The value of token.Token is
// written to the token_hash column.
func AddRefreshToken(token *models.RefreshToken) error {
	const insertStmt = "INSERT INTO refresh_tokens(id, user_id, token_hash, expires_at, created_at, revoked, device_info) VALUES (?, ?, ?, ?, ?, ?, ?)"
	_, err := DB.Exec(
		insertStmt,
		token.ID,
		token.UserID,
		token.Token,
		token.ExpiresAt,
		token.CreatedAt,
		token.Revoked,
		token.DeviceInfo,
	)
	return err
}
// GetRefreshToken loads the refresh token row whose token_hash column equals
// token. When no row matches, the error from Scan is sql.ErrNoRows (the same
// value as ErrNotFound).
//
// The columns are listed explicitly, in the order Scan expects them, so the
// lookup keeps working if the table later gains or reorders columns; a
// "SELECT *" would then fail or bind values to the wrong fields.
//
// NOTE(review): device_info is declared nullable in the schema; confirm that
// RefreshToken.DeviceInfo can receive NULL, or that NULL is never written.
func GetRefreshToken(token string) (models.RefreshToken, error) {
	row := DB.QueryRow(
		"SELECT id, user_id, token_hash, expires_at, created_at, revoked, device_info FROM refresh_tokens WHERE token_hash = ?",
		token,
	)
	var refreshToken models.RefreshToken
	err := row.Scan(
		&refreshToken.ID,
		&refreshToken.UserID,
		&refreshToken.Token,
		&refreshToken.ExpiresAt,
		&refreshToken.CreatedAt,
		&refreshToken.Revoked,
		&refreshToken.DeviceInfo,
	)
	return refreshToken, err
}
// RevokeRefreshToken sets revoked = 1 on the refresh token with the given
// id. It returns an error when the database has not been initialized and
// ErrNotFound when the statement affected no row.
func RevokeRefreshToken(tokenID string) error {
	if DB == nil {
		return errors.New("db not initialized")
	}

	const revokeStmt = `
UPDATE refresh_tokens
SET revoked = 1
WHERE id = ?
`
	outcome, execErr := DB.Exec(revokeStmt, tokenID)
	if execErr != nil {
		return execErr
	}
	changed, countErr := outcome.RowsAffected()
	switch {
	case countErr != nil:
		return countErr
	case changed == 0:
		return ErrNotFound
	}
	return nil
}
// RevokeAllRefreshTokensForUser sets revoked = 1 on every refresh token that
// belongs to userID. The affected-row count is not inspected, so a user
// without any tokens is not an error. Returns an error when the database has
// not been initialized or when the update fails.
func RevokeAllRefreshTokensForUser(userID string) error {
	if DB == nil {
		return errors.New("db not initialized")
	}

	const revokeAllStmt = `
UPDATE refresh_tokens
SET revoked = 1
WHERE user_id = ?
`
	if _, execErr := DB.Exec(revokeAllStmt, userID); execErr != nil {
		return execErr
	}
	return nil
}