From ea8ea45c4c4131b501047b27b06dccc3667324e3 Mon Sep 17 00:00:00 2001 From: miaurizius Date: Tue, 9 Jun 2026 22:50:29 +0200 Subject: [PATCH] started with 2fa support --- auth/jwt.go | 49 +++- deploy.sh | 5 - frontend/assets/js/login.js | 75 +++-- frontend/htmx/login.html | 10 +- go.mod | 2 + go.sum | 11 + handlers/account.go | 568 +++++++++++++++++++++++++++++------- models/dbmodels.go | 10 +- server/server.go | 7 +- storage/storage.go | 162 +++++++++- 10 files changed, 757 insertions(+), 142 deletions(-) delete mode 100755 deploy.sh diff --git a/auth/jwt.go b/auth/jwt.go index c09ae61..51cd772 100644 --- a/auth/jwt.go +++ b/auth/jwt.go @@ -1,6 +1,7 @@ package auth import ( + "errors" "time" "github.com/golang-jwt/jwt/v5" @@ -12,6 +13,12 @@ type Claims struct { jwt.RegisteredClaims } +type PurposeClaims struct { + UserID string `json:"user_id"` + Purpose string `json:"purpose"` + jwt.RegisteredClaims +} + func GenerateJWT(userID, role string, secret []byte) (string, error) { claims := Claims{ UserID: userID, @@ -25,8 +32,26 @@ func GenerateJWT(userID, role string, secret []byte) (string, error) { token := jwt.NewWithClaims(jwt.SigningMethodHS256, claims) return token.SignedString(secret) } + +func GeneratePurposeJWT(userID, purpose string, secret []byte, ttl time.Duration) (string, error) { + claims := PurposeClaims{ + UserID: userID, + Purpose: purpose, + RegisteredClaims: jwt.RegisteredClaims{ + ExpiresAt: jwt.NewNumericDate(time.Now().Add(ttl)), + IssuedAt: jwt.NewNumericDate(time.Now()), + }, + } + + token := jwt.NewWithClaims(jwt.SigningMethodHS256, claims) + return token.SignedString(secret) +} + func ValidateJWT(tokenStr string, secret []byte) (*Claims, error) { token, err := jwt.ParseWithClaims(tokenStr, &Claims{}, func(token *jwt.Token) (interface{}, error) { + if token.Method != jwt.SigningMethodHS256 { + return nil, errors.New("unexpected signing method") + } return secret, nil }) if err != nil { @@ -35,7 +60,29 @@ func ValidateJWT(tokenStr string, secret []byte) (*Claims, error) { claims, ok := token.Claims.(*Claims) if !ok || !token.Valid { - return nil, err + return nil, errors.New("invalid token") + } + + return claims, nil +} + +func ValidatePurposeJWT(tokenStr, expectedPurpose string, secret []byte) (*PurposeClaims, error) { + token, err := jwt.ParseWithClaims(tokenStr, &PurposeClaims{}, func(token *jwt.Token) (interface{}, error) { + if token.Method != jwt.SigningMethodHS256 { + return nil, errors.New("unexpected signing method") + } + return secret, nil + }) + if err != nil { + return nil, err + } + + claims, ok := token.Claims.(*PurposeClaims) + if !ok || !token.Valid { + return nil, errors.New("invalid token") + } + if claims.Purpose != expectedPurpose { + return nil, errors.New("invalid token purpose") } return claims, nil diff --git a/deploy.sh b/deploy.sh deleted file mode 100755 index c28e6c0..0000000 --- a/deploy.sh +++ /dev/null @@ -1,5 +0,0 @@ -sudo docker buildx build \ - --platform linux/amd64,linux/arm64 \ - -t git.miaurizius.de/miaurizius/miauinv:latest \ - -t git.miaurizius.de/miaurizius/miauinv:v1.0.2 \ - --push . \ No newline at end of file diff --git a/frontend/assets/js/login.js b/frontend/assets/js/login.js index 717eabe..16dbfcf 100644 --- a/frontend/assets/js/login.js +++ b/frontend/assets/js/login.js @@ -2,22 +2,63 @@ document.addEventListener("DOMContentLoaded", () => { const form = document.getElementById("login-form"); const errorBox = document.getElementById("error"); + const usernameInput = document.getElementById("username"); + const passwordInput = document.getElementById("password"); + const twoFactorInput = document.getElementById("two-factor-code"); + const twoFactorGroup = document.getElementById("two-factor-group"); + const submitButton = document.getElementById("login-submit"); + + let pendingTwoFactorToken = null; if (!form) return; + function showError(message) { + errorBox.textContent = message || "Login failed."; + errorBox.style.display = "block"; + } + + function storeTokens(data) { + localStorage.setItem("access_token", data.access_token); + localStorage.setItem("refresh_token", data.refresh_token); + } + + function switchToTwoFactorMode(token) { + pendingTwoFactorToken = token; + usernameInput.disabled = true; + passwordInput.disabled = true; + twoFactorGroup.style.display = "block"; + twoFactorInput.required = true; + twoFactorInput.focus(); + submitButton.textContent = "Verify code"; + } + form.addEventListener("submit", async (e) => { e.preventDefault(); errorBox.style.display = "none"; - - const username = document.getElementById("username").value; - const password = document.getElementById("password").value; + submitButton.disabled = true; try { - const response = await fetch("/api/login", { - method: "POST", - headers: { "Content-Type": "application/json" }, - body: JSON.stringify({ username, password }) - }); + let response; + + if (pendingTwoFactorToken) { + response = await fetch("/api/login/2fa", { + method: "POST", + headers: { "Content-Type": "application/json" }, + body: JSON.stringify({ + two_factor_token: pendingTwoFactorToken, + code: twoFactorInput.value.trim() + }) + }); + } else { + response = await fetch("/api/login", { + method: "POST", + headers: { "Content-Type": "application/json" }, + body: JSON.stringify({ + username: usernameInput.value, + password: passwordInput.value + }) + }); + } if (!response.ok) { const text = await response.text(); @@ -26,17 +67,17 @@ document.addEventListener("DOMContentLoaded", () => { const data = await response.json(); - localStorage.setItem("access_token", data.access_token); - localStorage.setItem("refresh_token", data.refresh_token); - - document.cookie = `access_token=${data.access_token}; path=/; max-age=900; SameSite=Lax; Secure`; - document.cookie = `refresh_token=${data.refresh_token}; path=/; max-age=604800; SameSite=Lax; Secure`; + if (data.requires_2fa) { + switchToTwoFactorMode(data.two_factor_token); + return; + } + storeTokens(data); window.location.href = "/dashboard"; - } catch (err) { - errorBox.textContent = err.message || "Login failed."; - errorBox.style.display = "block"; + showError(err.message); + } finally { + submitButton.disabled = false; } }); -}); \ No newline at end of file +}); diff --git a/frontend/htmx/login.html b/frontend/htmx/login.html index 6fbbf8d..dc88210 100644 --- a/frontend/htmx/login.html +++ b/frontend/htmx/login.html @@ -24,12 +24,18 @@ -
+
- + + +
diff --git a/go.mod b/go.mod index b8461b1..db300f5 100644 --- a/go.mod +++ b/go.mod @@ -6,12 +6,14 @@ require ( github.com/glebarez/go-sqlite v1.22.0 github.com/golang-jwt/jwt/v5 v5.3.1 github.com/google/uuid v1.5.0 + github.com/pquerna/otp v1.5.0 github.com/tdewolff/minify/v2 v2.24.13 golang.org/x/crypto v0.52.0 gopkg.in/yaml.v3 v3.0.1 ) require ( + github.com/boombuler/barcode v1.0.1-0.20190219062509-6c824513bacc // indirect github.com/dustin/go-humanize v1.0.1 // indirect github.com/mattn/go-isatty v0.0.20 // indirect github.com/remyoudompheng/bigfft v0.0.0-20230129092748-24d4a6f8daec // indirect diff --git a/go.sum b/go.sum index 9a82097..f8fc0a3 100644 --- a/go.sum +++ b/go.sum @@ -1,3 +1,7 @@ +github.com/boombuler/barcode v1.0.1-0.20190219062509-6c824513bacc h1:biVzkmvwrH8WK8raXaxBx6fRVTlJILwEwQGL1I/ByEI= +github.com/boombuler/barcode v1.0.1-0.20190219062509-6c824513bacc/go.mod h1:paBWMcWSl3LHKBqUq+rly7CNSldXjb2rDl3JlRe0mD8= +github.com/davecgh/go-spew v1.1.0 h1:ZDRjVQ15GmhC3fiQ8ni8+OwkZQO4DARzQgrnXU1Liz8= +github.com/davecgh/go-spew v1.1.0/go.mod h1:J7Y8YcW2NihsgmVo/mv3lAwl/skON4iLHjSsI+c5H38= github.com/dustin/go-humanize v1.0.1 h1:GzkhY7T5VNhEkwH0PVJgjz+fX1rhBrR7pRT3mDkpeCY= github.com/dustin/go-humanize v1.0.1/go.mod h1:Mu1zIs6XwVuF/gI1OepvI0qD18qycQx+mFykh5fBlto= github.com/glebarez/go-sqlite v1.22.0 h1:uAcMJhaA6r3LHMTFgP0SifzgXg46yJkgxqyuyec+ruQ= @@ -10,8 +14,15 @@ github.com/google/uuid v1.5.0 h1:1p67kYwdtXjb0gL0BPiP1Av9wiZPo5A8z2cWkTZ+eyU= github.com/google/uuid v1.5.0/go.mod h1:TIyPZe4MgqvfeYDBFedMoGGpEw/LqOeaOT+nhxU+yHo= github.com/mattn/go-isatty v0.0.20 h1:xfD0iDuEKnDkl03q4limB+vH+GxLEtL/jb4xVJSWWEY= github.com/mattn/go-isatty v0.0.20/go.mod h1:W+V8PltTTMOvKvAeJH7IuucS94S2C6jfK/D7dTCTo3Y= +github.com/pmezard/go-difflib v1.0.0 h1:4DBwDE0NGyQoBHbLQYPwSUPoCMWR5BEzIk/f1lZbAQM= +github.com/pmezard/go-difflib v1.0.0/go.mod h1:iKH77koFhYxTK1pcRnkKkqfTogsbg7gZNVY4sRDYZ/4= +github.com/pquerna/otp v1.5.0 h1:NMMR+WrmaqXU4EzdGJEE1aUUI0AMRzsp96fFFWNPwxs= +github.com/pquerna/otp v1.5.0/go.mod h1:dkJfzwRKNiegxyNb54X/3fLwhCynbMspSyWKnvi1AEg= github.com/remyoudompheng/bigfft v0.0.0-20230129092748-24d4a6f8daec h1:W09IVJc94icq4NjY3clb7Lk8O1qJ8BdBEF8z0ibU0rE= github.com/remyoudompheng/bigfft v0.0.0-20230129092748-24d4a6f8daec/go.mod h1:qqbHyh8v60DhA7CoWK5oRCqLrMHRGoxYCSS9EjAz6Eo= +github.com/stretchr/objx v0.1.0/go.mod h1:HFkY916IF+rwdDfMAkV7OtwuqBVzrE8GR6GFx+wExME= +github.com/stretchr/testify v1.3.0 h1:TivCn/peBQ7UY8ooIcPgZFpTNSz0Q2U6UrFlUfqbe0Q= +github.com/stretchr/testify v1.3.0/go.mod h1:M5WIy9Dh21IEIfnGCwXGc5bZfKNJtfHm1UVUgZn+9EI= github.com/tdewolff/minify/v2 v2.24.13 h1:xrcF7gKDnUszseEY9WX9mUlZII2v2Go/QAcAwRASw58= github.com/tdewolff/minify/v2 v2.24.13/go.mod h1:emvwoYeIl8bfAKqRU5ww95LX9Gpggpqv/naal9a8Yq0= github.com/tdewolff/parse/v2 v2.8.12 h1:5BBjfaCv482v3nltlS0u6wH1xJaxjR6ofDrWttNvROg= diff --git a/handlers/account.go b/handlers/account.go index 4db49fc..283da61 100644 --- a/handlers/account.go +++ b/handlers/account.go @@ -2,19 +2,23 @@ package handlers import ( "MiauInv/auth" - "MiauInv/config" "MiauInv/models" "MiauInv/storage" - "MiauInv/util" + utils "MiauInv/util" + "bytes" + "crypto/rand" + "encoding/base64" + "encoding/hex" "encoding/json" + "image/png" "log" "net/http" "os" "strings" "time" -) -var cfg, _ = config.LoadConfig() + "github.com/pquerna/otp/totp" +) func APIRegister(w http.ResponseWriter, r *http.Request) { var user models.User @@ -55,6 +59,7 @@ func APIRegister(w http.ResponseWriter, r *http.Request) { w.WriteHeader(http.StatusCreated) log.Println("POST [api/register] " + r.RemoteAddr + ": Successfully created user") } + func APILogin(w http.ResponseWriter, r *http.Request) { var creds struct { Username string `json:"username"` @@ -86,76 +91,300 @@ func APILogin(w http.ResponseWriter, r *http.Request) { return } - accessToken, err := auth.GenerateJWT(user.ID, user.Role, secret) - if err != nil { - log.Println("POST [api/login] " + r.RemoteAddr + ": " + err.Error()) - http.Error(w, "Could not generate token", http.StatusInternalServerError) + if user.TwoFactorEnabled { + twoFactorToken, err := auth.GeneratePurposeJWT(user.ID, "2fa_login", secret, 5*time.Minute) + if err != nil { + log.Println("POST [api/login] " + r.RemoteAddr + ": " + err.Error()) + http.Error(w, "Could not generate 2FA challenge", http.StatusInternalServerError) + return + } + + writeJSON(w, http.StatusOK, map[string]interface{}{ + "requires_2fa": true, + "two_factor_token": twoFactorToken, + }) + log.Println("POST [api/login] " + r.RemoteAddr + ": Password accepted, 2FA required") return } - refreshTokenPlain, err := utils.GenerateRefreshToken() - if err != nil { - log.Println("POST [api/login] " + r.RemoteAddr + ": " + err.Error()) - http.Error(w, "could not generate refresh token", http.StatusInternalServerError) - return - } - refreshHash := utils.HashToken(refreshTokenPlain) - refreshID := utils.GenerateUUID() - refreshExpires := time.Now().Add(7 * 24 * time.Hour).Unix() // expiry: 7 days - - deviceInfo := r.Header.Get("User-Agent") - - if err := storage.AddRefreshToken(&models.RefreshToken{ - ID: refreshID, - UserID: user.ID, - Token: refreshHash, - ExpiresAt: refreshExpires, - DeviceInfo: deviceInfo, - CreatedAt: time.Now().Unix(), - Revoked: false, - }); err != nil { - log.Println("POST [api/login] " + r.RemoteAddr + ": " + err.Error()) - http.Error(w, "could not save refresh token", http.StatusInternalServerError) - return - } - - // Return access + refresh token (refresh in plain for client to store securely) - resp := map[string]interface{}{ - "access_token": accessToken, - "refresh_token": refreshTokenPlain, - "user": map[string]interface{}{ - "id": user.ID, - "username": user.Username, - "role": user.Role, - }, - } - - http.SetCookie(w, &http.Cookie{ - Name: "access_token", - Value: accessToken, - Path: "/", - HttpOnly: true, - Secure: true, - SameSite: http.SameSiteLaxMode, - }) - http.SetCookie(w, &http.Cookie{ - Name: "refresh_token", - Value: refreshTokenPlain, - Path: "/", - HttpOnly: true, - Secure: true, - SameSite: http.SameSiteLaxMode, - }) - - w.Header().Set("Content-Type", "application/json") - err = json.NewEncoder(w).Encode(resp) - if err != nil { - log.Println("POST [api/login] " + r.RemoteAddr + ": " + err.Error()) - http.Error(w, "Something went wrong", http.StatusInternalServerError) - return - } + issueLoginSession(w, r, user) log.Println("POST [api/login] " + r.RemoteAddr + ": Successfully logged in") } + +func APILoginTwoFactor(w http.ResponseWriter, r *http.Request) { + var req struct { + TwoFactorToken string `json:"two_factor_token"` + Code string `json:"code"` + } + if err := json.NewDecoder(r.Body).Decode(&req); err != nil { + log.Println("POST [api/login/2fa] " + r.RemoteAddr + ": " + err.Error()) + http.Error(w, "Invalid request", http.StatusBadRequest) + return + } + + secret := []byte(os.Getenv("JWT_SECRET")) + claims, err := auth.ValidatePurposeJWT(req.TwoFactorToken, "2fa_login", secret) + if err != nil { + log.Println("POST [api/login/2fa] " + r.RemoteAddr + ": " + err.Error()) + http.Error(w, "Invalid or expired 2FA challenge", http.StatusUnauthorized) + return + } + + user, err := storage.GetUserById(claims.UserID) + if err != nil || !user.TwoFactorEnabled || user.TwoFactorSecret == "" { + log.Println("POST [api/login/2fa] " + r.RemoteAddr + ": 2FA not available for user") + http.Error(w, "Invalid 2FA state", http.StatusUnauthorized) + return + } + + code := strings.TrimSpace(req.Code) + validTOTP := totp.Validate(code, user.TwoFactorSecret) + usedRecoveryCode := false + + if !validTOTP { + recoveryCodeHash := utils.HashToken(normalizeRecoveryCode(code)) + usedRecoveryCode, err = storage.UseUserRecoveryCode(user.ID, recoveryCodeHash) + if err != nil { + log.Println("POST [api/login/2fa] " + r.RemoteAddr + ": " + err.Error()) + http.Error(w, "Could not validate recovery code", http.StatusInternalServerError) + return + } + } + + if !validTOTP && !usedRecoveryCode { + log.Println("POST [api/login/2fa] " + r.RemoteAddr + ": Invalid 2FA or recovery code") + http.Error(w, "Invalid 2FA or recovery code", http.StatusUnauthorized) + return + } + + issueLoginSession(w, r, user) + if usedRecoveryCode { + log.Println("POST [api/login/2fa] " + r.RemoteAddr + ": Successfully logged in with recovery code") + return + } + log.Println("POST [api/login/2fa] " + r.RemoteAddr + ": Successfully logged in with 2FA") +} + +func TwoFactorSetup(w http.ResponseWriter, r *http.Request) { + if r.Method != http.MethodPost { + http.Error(w, "Method not allowed", http.StatusMethodNotAllowed) + return + } + + claims := r.Context().Value(auth.UserContextKey).(*auth.Claims) + user, err := storage.GetUserById(claims.UserID) + if err != nil { + log.Println("POST [api/2fa/setup] " + r.RemoteAddr + ": " + err.Error()) + http.Error(w, "User not found", http.StatusNotFound) + return + } + + if user.TwoFactorEnabled { + http.Error(w, "2FA is already enabled", http.StatusConflict) + return + } + + key, err := totp.Generate(totp.GenerateOpts{ + Issuer: "MiauInv", + AccountName: user.Username, + SecretSize: 20, + }) + if err != nil { + log.Println("POST [api/2fa/setup] " + r.RemoteAddr + ": " + err.Error()) + http.Error(w, "Could not generate 2FA secret", http.StatusInternalServerError) + return + } + + if err := storage.SetUserTwoFactorSecret(user.ID, key.Secret()); err != nil { + log.Println("POST [api/2fa/setup] " + r.RemoteAddr + ": " + err.Error()) + http.Error(w, "Could not save 2FA secret", http.StatusInternalServerError) + return + } + + img, err := key.Image(220, 220) + if err != nil { + log.Println("POST [api/2fa/setup] " + r.RemoteAddr + ": " + err.Error()) + http.Error(w, "Could not generate QR code", http.StatusInternalServerError) + return + } + + var qr bytes.Buffer + if err := png.Encode(&qr, img); err != nil { + log.Println("POST [api/2fa/setup] " + r.RemoteAddr + ": " + err.Error()) + http.Error(w, "Could not encode QR code", http.StatusInternalServerError) + return + } + + writeJSON(w, http.StatusOK, map[string]interface{}{ + "secret": key.Secret(), + "otpauth_url": key.URL(), + "qr_code": "data:image/png;base64," + base64.StdEncoding.EncodeToString(qr.Bytes()), + }) + log.Println("POST [api/2fa/setup] " + r.RemoteAddr + ": Created 2FA setup challenge") +} + +func TwoFactorEnable(w http.ResponseWriter, r *http.Request) { + if r.Method != http.MethodPost { + http.Error(w, "Method not allowed", http.StatusMethodNotAllowed) + return + } + + var req struct { + Code string `json:"code"` + } + if err := json.NewDecoder(r.Body).Decode(&req); err != nil { + log.Println("POST [api/2fa/enable] " + r.RemoteAddr + ": " + err.Error()) + http.Error(w, "Invalid request", http.StatusBadRequest) + return + } + + claims := r.Context().Value(auth.UserContextKey).(*auth.Claims) + user, err := storage.GetUserById(claims.UserID) + if err != nil { + log.Println("POST [api/2fa/enable] " + r.RemoteAddr + ": " + err.Error()) + http.Error(w, "User not found", http.StatusNotFound) + return + } + + if user.TwoFactorSecret == "" { + http.Error(w, "2FA setup has not been started", http.StatusBadRequest) + return + } + + if !totp.Validate(strings.TrimSpace(req.Code), user.TwoFactorSecret) { + http.Error(w, "Invalid 2FA code", http.StatusUnauthorized) + return + } + + recoveryCodes, recoveryCodeHashes, err := generateRecoveryCodes(10) + if err != nil { + log.Println("POST [api/2fa/enable] " + r.RemoteAddr + ": " + err.Error()) + http.Error(w, "Could not generate recovery codes", http.StatusInternalServerError) + return + } + + if err := storage.EnableUserTwoFactorWithRecoveryCodes(user.ID, recoveryCodeHashes); err != nil { + log.Println("POST [api/2fa/enable] " + r.RemoteAddr + ": " + err.Error()) + http.Error(w, "Could not enable 2FA", http.StatusInternalServerError) + return + } + + writeJSON(w, http.StatusOK, map[string]interface{}{ + "two_factor_enabled": true, + "recovery_codes": recoveryCodes, + }) + log.Println("POST [api/2fa/enable] " + r.RemoteAddr + ": Enabled 2FA and generated recovery codes") +} + +func TwoFactorDisable(w http.ResponseWriter, r *http.Request) { + if r.Method != http.MethodPost { + http.Error(w, "Method not allowed", http.StatusMethodNotAllowed) + return + } + + var req struct { + Password string `json:"password"` + Code string `json:"code"` + } + if err := json.NewDecoder(r.Body).Decode(&req); err != nil { + log.Println("POST [api/2fa/disable] " + r.RemoteAddr + ": " + err.Error()) + http.Error(w, "Invalid request", http.StatusBadRequest) + return + } + + claims := r.Context().Value(auth.UserContextKey).(*auth.Claims) + user, err := storage.GetUserById(claims.UserID) + if err != nil { + log.Println("POST [api/2fa/disable] " + r.RemoteAddr + ": " + err.Error()) + http.Error(w, "User not found", http.StatusNotFound) + return + } + + if !auth.CheckPasswordHash(req.Password, user.Password) { + http.Error(w, "Invalid password", http.StatusUnauthorized) + return + } + + if user.TwoFactorEnabled && !totp.Validate(strings.TrimSpace(req.Code), user.TwoFactorSecret) { + http.Error(w, "Invalid 2FA code", http.StatusUnauthorized) + return + } + + if err := storage.DisableUserTwoFactor(user.ID); err != nil { + log.Println("POST [api/2fa/disable] " + r.RemoteAddr + ": " + err.Error()) + http.Error(w, "Could not disable 2FA", http.StatusInternalServerError) + return + } + + if err := storage.RevokeAllRefreshTokensForUser(user.ID); err != nil { + log.Println("POST [api/2fa/disable] " + r.RemoteAddr + ": " + err.Error()) + http.Error(w, "Could not revoke sessions", http.StatusInternalServerError) + return + } + + clearAuthCookies(w) + writeJSON(w, http.StatusOK, map[string]interface{}{"two_factor_enabled": false}) + log.Println("POST [api/2fa/disable] " + r.RemoteAddr + ": Disabled 2FA") +} + +func TwoFactorRegenerateRecoveryCodes(w http.ResponseWriter, r *http.Request) { + if r.Method != http.MethodPost { + http.Error(w, "Method not allowed", http.StatusMethodNotAllowed) + return + } + + var req struct { + Password string `json:"password"` + Code string `json:"code"` + } + if err := json.NewDecoder(r.Body).Decode(&req); err != nil { + log.Println("POST [api/2fa/recovery-codes/regenerate] " + r.RemoteAddr + ": " + err.Error()) + http.Error(w, "Invalid request", http.StatusBadRequest) + return + } + + claims := r.Context().Value(auth.UserContextKey).(*auth.Claims) + user, err := storage.GetUserById(claims.UserID) + if err != nil { + log.Println("POST [api/2fa/recovery-codes/regenerate] " + r.RemoteAddr + ": " + err.Error()) + http.Error(w, "User not found", http.StatusNotFound) + return + } + + if !user.TwoFactorEnabled || user.TwoFactorSecret == "" { + http.Error(w, "2FA is not enabled", http.StatusBadRequest) + return + } + + if !auth.CheckPasswordHash(req.Password, user.Password) { + http.Error(w, "Invalid password", http.StatusUnauthorized) + return + } + + if !totp.Validate(strings.TrimSpace(req.Code), user.TwoFactorSecret) { + http.Error(w, "Invalid 2FA code", http.StatusUnauthorized) + return + } + + recoveryCodes, recoveryCodeHashes, err := generateRecoveryCodes(10) + if err != nil { + log.Println("POST [api/2fa/recovery-codes/regenerate] " + r.RemoteAddr + ": " + err.Error()) + http.Error(w, "Could not generate recovery codes", http.StatusInternalServerError) + return + } + + if err := storage.ReplaceUserRecoveryCodes(user.ID, recoveryCodeHashes); err != nil { + log.Println("POST [api/2fa/recovery-codes/regenerate] " + r.RemoteAddr + ": " + err.Error()) + http.Error(w, "Could not save recovery codes", http.StatusInternalServerError) + return + } + + writeJSON(w, http.StatusOK, map[string]interface{}{ + "recovery_codes": recoveryCodes, + }) + log.Println("POST [api/2fa/recovery-codes/regenerate] " + r.RemoteAddr + ": Regenerated recovery codes") +} + func Logout(w http.ResponseWriter, r *http.Request) { claims := r.Context().Value(auth.UserContextKey).(*auth.Claims) err := storage.RevokeAllRefreshTokensForUser(claims.UserID) @@ -164,8 +393,10 @@ func Logout(w http.ResponseWriter, r *http.Request) { http.Error(w, "Internal server error", http.StatusInternalServerError) return } - w.WriteHeader(204) + clearAuthCookies(w) + w.WriteHeader(http.StatusNoContent) } + func TestHandler(w http.ResponseWriter, r *http.Request) { claims, _ := utils.IsLoggedIn(w, r) @@ -181,13 +412,24 @@ func TestHandler(w http.ResponseWriter, r *http.Request) { } log.Println("GET [api/ping] " + r.RemoteAddr + ": Successfully tested connection") } + func RefreshToken(w http.ResponseWriter, r *http.Request) { var req struct { RefreshToken string `json:"refresh_token"` } - if err := json.NewDecoder(r.Body).Decode(&req); err != nil { - log.Println("POST [api/refresh] " + r.RemoteAddr + ": " + err.Error()) - http.Error(w, "Invalid request", http.StatusBadRequest) + + if r.Body != nil { + _ = json.NewDecoder(r.Body).Decode(&req) + } + if req.RefreshToken == "" { + cookie, err := r.Cookie("refresh_token") + if err == nil { + req.RefreshToken = cookie.Value + } + } + if req.RefreshToken == "" { + log.Println("POST [api/refresh] " + r.RemoteAddr + ": Missing refresh token") + http.Error(w, "Invalid refresh token", http.StatusUnauthorized) return } @@ -204,43 +446,17 @@ func RefreshToken(w http.ResponseWriter, r *http.Request) { log.Println(err) } - newToken, _ := utils.GenerateRefreshToken() - newHash := utils.HashToken(newToken) - newExpires := time.Now().Add(7 * 24 * time.Hour).Unix() //7 days - newID := utils.GenerateUUID() - deviceInfo := r.Header.Get("User-Agent") - if err = storage.AddRefreshToken(&models.RefreshToken{ - ID: newID, - UserID: tokenRow.UserID, - Token: newHash, - ExpiresAt: newExpires, - CreatedAt: time.Now().Unix(), - Revoked: false, - DeviceInfo: deviceInfo, - }); err != nil { - log.Println("POST [api/refresh] " + r.RemoteAddr + ": " + err.Error()) - http.Error(w, "Could not generate new refresh token", http.StatusInternalServerError) - return - } - user, err := storage.GetUserById(tokenRow.UserID) if err != nil { log.Println("POST [api/refresh] " + r.RemoteAddr + ": " + err.Error()) http.Error(w, "Internal server error", http.StatusInternalServerError) return } - accessToken, _ := auth.GenerateJWT(tokenRow.UserID, user.Role, []byte(os.Getenv("JWT_SECRET"))) - if err = json.NewEncoder(w).Encode(map[string]string{ - "access_token": accessToken, - "refresh_token": newToken, - }); err != nil { - log.Println("POST [api/refresh] " + r.RemoteAddr + ": " + err.Error()) - http.Error(w, "Internal server error", http.StatusInternalServerError) - return - } + issueLoginSession(w, r, user) log.Println("POST [api/refresh] " + r.RemoteAddr + ": Successfully refreshed token") } + func UserInfo(w http.ResponseWriter, r *http.Request) { if r.Method != http.MethodGet { log.Println("GET [api/userinfo] " + r.RemoteAddr + ": Method " + r.Method + " not allowed") @@ -287,11 +503,21 @@ func UserInfo(w http.ResponseWriter, r *http.Request) { http.Error(w, "User not found", http.StatusNotFound) return } + + recoveryCodesRemaining := 0 + if user.TwoFactorEnabled { + if count, err := storage.CountUnusedRecoveryCodes(user.ID); err == nil { + recoveryCodesRemaining = count + } + } + w.Header().Set("Content-Type", "application/json") err = json.NewEncoder(w).Encode(map[string]interface{}{ - "id": user.ID, - "username": user.Username, - "avatar_url": "", + "id": user.ID, + "username": user.Username, + "avatar_url": "", + "two_factor_enabled": user.TwoFactorEnabled, + "recovery_codes_remaining": recoveryCodesRemaining, }) if err != nil { log.Println("GET [api/userinfo] " + r.RemoteAddr + ": " + err.Error()) @@ -299,3 +525,137 @@ func UserInfo(w http.ResponseWriter, r *http.Request) { } log.Println("GET [api/userinfo] " + r.RemoteAddr + ": Successfully retrieved user info of " + user.Username + " (" + user.ID + ")") } + +func issueLoginSession(w http.ResponseWriter, r *http.Request, user models.User) { + secret := []byte(os.Getenv("JWT_SECRET")) + if len(secret) == 0 { + log.Println("AUTH " + r.RemoteAddr + ": Server misconfiguration") + http.Error(w, "Server misconfiguration", http.StatusInternalServerError) + return + } + + accessToken, err := auth.GenerateJWT(user.ID, user.Role, secret) + if err != nil { + log.Println("AUTH " + r.RemoteAddr + ": " + err.Error()) + http.Error(w, "Could not generate token", http.StatusInternalServerError) + return + } + + refreshTokenPlain, err := utils.GenerateRefreshToken() + if err != nil { + log.Println("AUTH " + r.RemoteAddr + ": " + err.Error()) + http.Error(w, "could not generate refresh token", http.StatusInternalServerError) + return + } + + refreshExpires := time.Now().Add(7 * 24 * time.Hour).Unix() + if err := storage.AddRefreshToken(&models.RefreshToken{ + ID: utils.GenerateUUID(), + UserID: user.ID, + Token: utils.HashToken(refreshTokenPlain), + ExpiresAt: refreshExpires, + DeviceInfo: r.Header.Get("User-Agent"), + CreatedAt: time.Now().Unix(), + Revoked: false, + }); err != nil { + log.Println("AUTH " + r.RemoteAddr + ": " + err.Error()) + http.Error(w, "could not save refresh token", http.StatusInternalServerError) + return + } + + setAuthCookies(w, accessToken, refreshTokenPlain) + writeJSON(w, http.StatusOK, map[string]interface{}{ + "access_token": accessToken, + "refresh_token": refreshTokenPlain, + "user": map[string]interface{}{ + "id": user.ID, + "username": user.Username, + "role": user.Role, + "two_factor_enabled": user.TwoFactorEnabled, + }, + }) +} + +func generateRecoveryCodes(count int) ([]string, []string, error) { + codes := make([]string, 0, count) + hashes := make([]string, 0, count) + seen := make(map[string]struct{}, count) + + for len(codes) < count { + code, err := generateRecoveryCode() + if err != nil { + return nil, nil, err + } + + normalized := normalizeRecoveryCode(code) + if _, exists := seen[normalized]; exists { + continue + } + seen[normalized] = struct{}{} + + codes = append(codes, code) + hashes = append(hashes, utils.HashToken(normalized)) + } + + return codes, hashes, nil +} + +func generateRecoveryCode() (string, error) { + bytes := make([]byte, 10) + if _, err := rand.Read(bytes); err != nil { + return "", err + } + + raw := hex.EncodeToString(bytes) + return raw[0:5] + "-" + raw[5:10] + "-" + raw[10:15] + "-" + raw[15:20], nil +} + +func normalizeRecoveryCode(code string) string { + code = strings.TrimSpace(code) + code = strings.ReplaceAll(code, "-", "") + code = strings.ReplaceAll(code, " ", "") + return strings.ToLower(code) +} + +func setAuthCookies(w http.ResponseWriter, accessToken, refreshToken string) { + http.SetCookie(w, &http.Cookie{ + Name: "access_token", + Value: accessToken, + Path: "/", + MaxAge: 15 * 60, + HttpOnly: true, + Secure: true, + SameSite: http.SameSiteLaxMode, + }) + http.SetCookie(w, &http.Cookie{ + Name: "refresh_token", + Value: refreshToken, + Path: "/", + MaxAge: 7 * 24 * 60 * 60, + HttpOnly: true, + Secure: true, + SameSite: http.SameSiteLaxMode, + }) +} + +func clearAuthCookies(w http.ResponseWriter) { + for _, name := range []string{"access_token", "refresh_token"} { + http.SetCookie(w, &http.Cookie{ + Name: name, + Value: "", + Path: "/", + MaxAge: -1, + HttpOnly: true, + Secure: true, + SameSite: http.SameSiteLaxMode, + }) + } +} + +func writeJSON(w http.ResponseWriter, status int, data interface{}) { + w.Header().Set("Content-Type", "application/json") + w.WriteHeader(status) + if err := json.NewEncoder(w).Encode(data); err != nil { + log.Println("JSON response error: " + err.Error()) + } +} diff --git a/models/dbmodels.go b/models/dbmodels.go index dd2f93a..f777642 100644 --- a/models/dbmodels.go +++ b/models/dbmodels.go @@ -1,8 +1,10 @@ package models type User struct { - ID string `json:"id"` - Username string `json:"username"` - Password string `json:"password"` - Role string `json:"role"` + ID string `json:"id"` + Username string `json:"username"` + Password string `json:"password"` + Role string `json:"role"` + TwoFactorEnabled bool `json:"two_factor_enabled"` + TwoFactorSecret string `json:"-"` } diff --git a/server/server.go b/server/server.go index d1fed6b..6391363 100644 --- a/server/server.go +++ b/server/server.go @@ -82,10 +82,15 @@ func (this *Server) Run() { // API // mux.HandleFunc("/api/login", handlers.APILogin) + mux.HandleFunc("/api/login/2fa", handlers.APILoginTwoFactor) mux.HandleFunc("/api/refresh", handlers.RefreshToken) mux.Handle("/api/logout", auth.AuthMiddleware(this.JWTSecret)(http.HandlerFunc(handlers.Logout))) mux.Handle("/api/profile", auth.AuthMiddleware(this.JWTSecret)(http.HandlerFunc(handlers.UserInfo))) - mux.HandleFunc("/api/userinfo", handlers.UserInfo) + mux.Handle("/api/2fa/setup", auth.AuthMiddleware(this.JWTSecret)(http.HandlerFunc(handlers.TwoFactorSetup))) + mux.Handle("/api/2fa/enable", auth.AuthMiddleware(this.JWTSecret)(http.HandlerFunc(handlers.TwoFactorEnable))) + mux.Handle("/api/2fa/disable", auth.AuthMiddleware(this.JWTSecret)(http.HandlerFunc(handlers.TwoFactorDisable))) + mux.Handle("/api/2fa/recovery-codes/regenerate", auth.AuthMiddleware(this.JWTSecret)(http.HandlerFunc(handlers.TwoFactorRegenerateRecoveryCodes))) + mux.Handle("/api/userinfo", auth.AuthMiddleware(this.JWTSecret)(http.HandlerFunc(handlers.UserInfo))) if this.AllowRegistration { mux.HandleFunc("/api/register", handlers.APIRegister) } diff --git a/storage/storage.go b/storage/storage.go index 724c1c7..cf40588 100644 --- a/storage/storage.go +++ b/storage/storage.go @@ -2,10 +2,12 @@ package storage import ( "MiauInv/models" + utils "MiauInv/util" "database/sql" "errors" "log" "strings" + "time" _ "github.com/glebarez/go-sqlite" ) @@ -27,7 +29,9 @@ func InitDB(filepath string) error { id TEXT PRIMARY KEY, username TEXT NOT NULL UNIQUE, password TEXT NOT NULL, - role TEXT NOT NULL + role TEXT NOT NULL, + two_factor_enabled INTEGER NOT NULL DEFAULT 0, + two_factor_secret TEXT NOT NULL DEFAULT '' ); CREATE TABLE IF NOT EXISTS refresh_tokens ( @@ -41,6 +45,16 @@ func InitDB(filepath string) error { FOREIGN KEY(user_id) REFERENCES users(id) ); + CREATE TABLE IF NOT EXISTS two_factor_recovery_codes ( + id TEXT PRIMARY KEY, + user_id TEXT NOT NULL, + code_hash TEXT NOT NULL, + created_at INTEGER NOT NULL, + used_at INTEGER DEFAULT NULL, + FOREIGN KEY(user_id) REFERENCES users(id) ON DELETE CASCADE, + UNIQUE(user_id, code_hash) + ); + CREATE TABLE IF NOT EXISTS items ( id INTEGER PRIMARY KEY AUTOINCREMENT, name TEXT NOT NULL, @@ -84,7 +98,26 @@ func InitDB(filepath string) error { log.Fatal(err) } - return err + if err := ensureUserTwoFactorColumns(); err != nil { + return err + } + + return nil +} + +func ensureUserTwoFactorColumns() error { + migrations := []string{ + "ALTER TABLE users ADD COLUMN two_factor_enabled INTEGER NOT NULL DEFAULT 0", + "ALTER TABLE users ADD COLUMN two_factor_secret TEXT NOT NULL DEFAULT ''", + } + + for _, migration := range migrations { + _, err := DB.Exec(migration) + if err != nil && !strings.Contains(strings.ToLower(err.Error()), "duplicate column") { + return err + } + } + return nil } // Users @@ -93,18 +126,131 @@ func AddUser(user *models.User) error { return err } func GetUserByUsername(username string) (models.User, error) { - row := DB.QueryRow("SELECT * FROM users WHERE username = ?", strings.ToLower(username)) - var user models.User - err := row.Scan(&user.ID, &user.Username, &user.Password, &user.Role) - return user, err + row := DB.QueryRow(` + SELECT id, username, password, role, two_factor_enabled, two_factor_secret + FROM users + WHERE username = ? + `, strings.ToLower(username)) + return scanUser(row) } func GetUserById(id string) (models.User, error) { - row := DB.QueryRow("SELECT * FROM users WHERE id = ?", id) + row := DB.QueryRow(` + SELECT id, username, password, role, two_factor_enabled, two_factor_secret + FROM users + WHERE id = ? + `, id) + return scanUser(row) +} + +func scanUser(row *sql.Row) (models.User, error) { var user models.User - err := row.Scan(&user.ID, &user.Username, &user.Password, &user.Role) + var twoFactorEnabled int + err := row.Scan(&user.ID, &user.Username, &user.Password, &user.Role, &twoFactorEnabled, &user.TwoFactorSecret) + user.TwoFactorEnabled = twoFactorEnabled == 1 return user, err } +func SetUserTwoFactorSecret(userID, secret string) error { + _, err := DB.Exec("UPDATE users SET two_factor_secret = ? WHERE id = ?", secret, userID) + return err +} + +func EnableUserTwoFactorWithRecoveryCodes(userID string, recoveryCodeHashes []string) error { + tx, err := DB.Begin() + if err != nil { + return err + } + defer tx.Rollback() + + if _, err := tx.Exec("DELETE FROM two_factor_recovery_codes WHERE user_id = ?", userID); err != nil { + return err + } + + now := time.Now().Unix() + for _, codeHash := range recoveryCodeHashes { + if _, err := tx.Exec(` + INSERT INTO two_factor_recovery_codes(id, user_id, code_hash, created_at) + VALUES (?, ?, ?, ?) + `, utils.GenerateUUID(), userID, codeHash, now); err != nil { + return err + } + } + + if _, err := tx.Exec("UPDATE users SET two_factor_enabled = 1 WHERE id = ?", userID); err != nil { + return err + } + + return tx.Commit() +} + +func DisableUserTwoFactor(userID string) error { + tx, err := DB.Begin() + if err != nil { + return err + } + defer tx.Rollback() + + if _, err := tx.Exec("UPDATE users SET two_factor_enabled = 0, two_factor_secret = '' WHERE id = ?", userID); err != nil { + return err + } + if _, err := tx.Exec("DELETE FROM two_factor_recovery_codes WHERE user_id = ?", userID); err != nil { + return err + } + + return tx.Commit() +} + +func ReplaceUserRecoveryCodes(userID string, recoveryCodeHashes []string) error { + tx, err := DB.Begin() + if err != nil { + return err + } + defer tx.Rollback() + + if _, err := tx.Exec("DELETE FROM two_factor_recovery_codes WHERE user_id = ?", userID); err != nil { + return err + } + + now := time.Now().Unix() + for _, codeHash := range recoveryCodeHashes { + if _, err := tx.Exec(` + INSERT INTO two_factor_recovery_codes(id, user_id, code_hash, created_at) + VALUES (?, ?, ?, ?) + `, utils.GenerateUUID(), userID, codeHash, now); err != nil { + return err + } + } + + return tx.Commit() +} + +func UseUserRecoveryCode(userID, codeHash string) (bool, error) { + res, err := DB.Exec(` + UPDATE two_factor_recovery_codes + SET used_at = ? + WHERE user_id = ? AND code_hash = ? AND used_at IS NULL + `, time.Now().Unix(), userID, codeHash) + if err != nil { + return false, err + } + + n, err := res.RowsAffected() + if err != nil { + return false, err + } + return n == 1, nil +} + +func CountUnusedRecoveryCodes(userID string) (int, error) { + var count int + err := DB.QueryRow(` + SELECT COUNT(*) + FROM two_factor_recovery_codes + WHERE user_id = ? AND used_at IS NULL + `, userID).Scan(&count) + return count, err +} + // Refresh Tokens func AddRefreshToken(token *models.RefreshToken) error { _, err := DB.Exec("INSERT INTO refresh_tokens(id, user_id, token_hash, expires_at, created_at, revoked, device_info) VALUES (?, ?, ?, ?, ?, ?, ?)",