4 Commits

Author SHA1 Message Date
ea8ea45c4c started with 2fa support 2026-06-09 22:50:29 +02:00
5485fd135d new version 2026-06-09 14:45:04 +02:00
5558d42bdb fixed #2 2026-06-09 14:44:28 +02:00
b74df36bda fixed #4 2026-06-09 14:40:49 +02:00
11 changed files with 789 additions and 146 deletions

View File

@@ -1,6 +1,7 @@
package auth package auth
import ( import (
"errors"
"time" "time"
"github.com/golang-jwt/jwt/v5" "github.com/golang-jwt/jwt/v5"
@@ -12,6 +13,12 @@ type Claims struct {
jwt.RegisteredClaims jwt.RegisteredClaims
} }
type PurposeClaims struct {
UserID string `json:"user_id"`
Purpose string `json:"purpose"`
jwt.RegisteredClaims
}
func GenerateJWT(userID, role string, secret []byte) (string, error) { func GenerateJWT(userID, role string, secret []byte) (string, error) {
claims := Claims{ claims := Claims{
UserID: userID, UserID: userID,
@@ -25,8 +32,26 @@ func GenerateJWT(userID, role string, secret []byte) (string, error) {
token := jwt.NewWithClaims(jwt.SigningMethodHS256, claims) token := jwt.NewWithClaims(jwt.SigningMethodHS256, claims)
return token.SignedString(secret) return token.SignedString(secret)
} }
func GeneratePurposeJWT(userID, purpose string, secret []byte, ttl time.Duration) (string, error) {
claims := PurposeClaims{
UserID: userID,
Purpose: purpose,
RegisteredClaims: jwt.RegisteredClaims{
ExpiresAt: jwt.NewNumericDate(time.Now().Add(ttl)),
IssuedAt: jwt.NewNumericDate(time.Now()),
},
}
token := jwt.NewWithClaims(jwt.SigningMethodHS256, claims)
return token.SignedString(secret)
}
func ValidateJWT(tokenStr string, secret []byte) (*Claims, error) { func ValidateJWT(tokenStr string, secret []byte) (*Claims, error) {
token, err := jwt.ParseWithClaims(tokenStr, &Claims{}, func(token *jwt.Token) (interface{}, error) { token, err := jwt.ParseWithClaims(tokenStr, &Claims{}, func(token *jwt.Token) (interface{}, error) {
if token.Method != jwt.SigningMethodHS256 {
return nil, errors.New("unexpected signing method")
}
return secret, nil return secret, nil
}) })
if err != nil { if err != nil {
@@ -35,7 +60,29 @@ func ValidateJWT(tokenStr string, secret []byte) (*Claims, error) {
claims, ok := token.Claims.(*Claims) claims, ok := token.Claims.(*Claims)
if !ok || !token.Valid { if !ok || !token.Valid {
return nil, err return nil, errors.New("invalid token")
}
return claims, nil
}
func ValidatePurposeJWT(tokenStr, expectedPurpose string, secret []byte) (*PurposeClaims, error) {
token, err := jwt.ParseWithClaims(tokenStr, &PurposeClaims{}, func(token *jwt.Token) (interface{}, error) {
if token.Method != jwt.SigningMethodHS256 {
return nil, errors.New("unexpected signing method")
}
return secret, nil
})
if err != nil {
return nil, err
}
claims, ok := token.Claims.(*PurposeClaims)
if !ok || !token.Valid {
return nil, errors.New("invalid token")
}
if claims.Purpose != expectedPurpose {
return nil, errors.New("invalid token purpose")
} }
return claims, nil return claims, nil

View File

@@ -1,5 +0,0 @@
sudo docker buildx build \
--platform linux/amd64,linux/arm64 \
-t git.miaurizius.de/miaurizius/miauinv:latest \
-t git.miaurizius.de/miaurizius/miauinv:v1.0.1 \
--push .

View File

@@ -2,22 +2,63 @@
document.addEventListener("DOMContentLoaded", () => { document.addEventListener("DOMContentLoaded", () => {
const form = document.getElementById("login-form"); const form = document.getElementById("login-form");
const errorBox = document.getElementById("error"); const errorBox = document.getElementById("error");
const usernameInput = document.getElementById("username");
const passwordInput = document.getElementById("password");
const twoFactorInput = document.getElementById("two-factor-code");
const twoFactorGroup = document.getElementById("two-factor-group");
const submitButton = document.getElementById("login-submit");
let pendingTwoFactorToken = null;
if (!form) return; if (!form) return;
function showError(message) {
errorBox.textContent = message || "Login failed.";
errorBox.style.display = "block";
}
function storeTokens(data) {
localStorage.setItem("access_token", data.access_token);
localStorage.setItem("refresh_token", data.refresh_token);
}
function switchToTwoFactorMode(token) {
pendingTwoFactorToken = token;
usernameInput.disabled = true;
passwordInput.disabled = true;
twoFactorGroup.style.display = "block";
twoFactorInput.required = true;
twoFactorInput.focus();
submitButton.textContent = "Verify code";
}
form.addEventListener("submit", async (e) => { form.addEventListener("submit", async (e) => {
e.preventDefault(); e.preventDefault();
errorBox.style.display = "none"; errorBox.style.display = "none";
submitButton.disabled = true;
const username = document.getElementById("username").value;
const password = document.getElementById("password").value;
try { try {
const response = await fetch("/api/login", { let response;
method: "POST",
headers: { "Content-Type": "application/json" }, if (pendingTwoFactorToken) {
body: JSON.stringify({ username, password }) response = await fetch("/api/login/2fa", {
}); method: "POST",
headers: { "Content-Type": "application/json" },
body: JSON.stringify({
two_factor_token: pendingTwoFactorToken,
code: twoFactorInput.value.trim()
})
});
} else {
response = await fetch("/api/login", {
method: "POST",
headers: { "Content-Type": "application/json" },
body: JSON.stringify({
username: usernameInput.value,
password: passwordInput.value
})
});
}
if (!response.ok) { if (!response.ok) {
const text = await response.text(); const text = await response.text();
@@ -26,17 +67,17 @@ document.addEventListener("DOMContentLoaded", () => {
const data = await response.json(); const data = await response.json();
localStorage.setItem("access_token", data.access_token); if (data.requires_2fa) {
localStorage.setItem("refresh_token", data.refresh_token); switchToTwoFactorMode(data.two_factor_token);
return;
document.cookie = `access_token=${data.access_token}; path=/; max-age=900; SameSite=Lax; Secure`; }
document.cookie = `refresh_token=${data.refresh_token}; path=/; max-age=604800; SameSite=Lax; Secure`;
storeTokens(data);
window.location.href = "/dashboard"; window.location.href = "/dashboard";
} catch (err) { } catch (err) {
errorBox.textContent = err.message || "Login failed."; showError(err.message);
errorBox.style.display = "block"; } finally {
submitButton.disabled = false;
} }
}); });
}); });

View File

@@ -1,6 +1,7 @@
package frontend package frontend
import ( import (
"MiauInv/storage"
"html/template" "html/template"
"net/http" "net/http"
"os" "os"
@@ -53,7 +54,28 @@ func Home(w http.ResponseWriter, r *http.Request) {
func Dashboard(w http.ResponseWriter, r *http.Request) { func Dashboard(w http.ResponseWriter, r *http.Request) {
w.Header().Set("Content-Type", "text/html") w.Header().Set("Content-Type", "text/html")
err := dashboard.ExecuteTemplate(w, "base.html", struct {
var itemHive, projectHive, locationHive int
err := storage.DB.QueryRow("SELECT COUNT(*) FROM items").Scan(&itemHive)
if err != nil {
http.Error(w, "Failed to count items", http.StatusInternalServerError)
return
}
err = storage.DB.QueryRow("SELECT COUNT(*) FROM projects").Scan(&projectHive)
if err != nil {
http.Error(w, "Failed to count projects", http.StatusInternalServerError)
return
}
err = storage.DB.QueryRow("SELECT COUNT(*) FROM locations").Scan(&locationHive)
if err != nil {
http.Error(w, "Failed to count locations", http.StatusInternalServerError)
return
}
err = dashboard.ExecuteTemplate(w, "base.html", struct {
Title string Title string
Stats struct { Stats struct {
Items int Items int
@@ -67,9 +89,9 @@ func Dashboard(w http.ResponseWriter, r *http.Request) {
Projects int Projects int
Locations int Locations int
}{ }{
Items: 1, Items: itemHive,
Projects: 1, Projects: projectHive,
Locations: 3, Locations: locationHive,
}, },
}) })
if err != nil { if err != nil {

View File

@@ -24,12 +24,18 @@
<input type="text" id="username" placeholder="Username" autocomplete="username" required> <input type="text" id="username" placeholder="Username" autocomplete="username" required>
</div> </div>
<div class="form-group"> <div class="form-group" id="password-group">
<label for="password" class="sr-only">Password</label> <label for="password" class="sr-only">Password</label>
<input type="password" id="password" placeholder="Password" autocomplete="current-password" required> <input type="password" id="password" placeholder="Password" autocomplete="current-password" required>
</div> </div>
<button type="submit" class="btn btn-primary">Sign In</button> <div class="form-group" id="two-factor-group" style="display: none;">
<label for="two-factor-code" class="sr-only">2FA code</label>
<input type="text" id="two-factor-code" placeholder="Authenticator or recovery code" autocomplete="one-time-code" inputmode="text" pattern="[0-9A-Za-z\- ]*">
<p class="subtitle" style="margin-top: 0.75rem;">Enter your 6-digit authenticator code or one recovery code.</p>
</div>
<button type="submit" id="login-submit" class="btn btn-primary">Sign In</button>
</form> </form>
<div id="error" class="message error"></div> <div id="error" class="message error"></div>

2
go.mod
View File

@@ -6,12 +6,14 @@ require (
github.com/glebarez/go-sqlite v1.22.0 github.com/glebarez/go-sqlite v1.22.0
github.com/golang-jwt/jwt/v5 v5.3.1 github.com/golang-jwt/jwt/v5 v5.3.1
github.com/google/uuid v1.5.0 github.com/google/uuid v1.5.0
github.com/pquerna/otp v1.5.0
github.com/tdewolff/minify/v2 v2.24.13 github.com/tdewolff/minify/v2 v2.24.13
golang.org/x/crypto v0.52.0 golang.org/x/crypto v0.52.0
gopkg.in/yaml.v3 v3.0.1 gopkg.in/yaml.v3 v3.0.1
) )
require ( require (
github.com/boombuler/barcode v1.0.1-0.20190219062509-6c824513bacc // indirect
github.com/dustin/go-humanize v1.0.1 // indirect github.com/dustin/go-humanize v1.0.1 // indirect
github.com/mattn/go-isatty v0.0.20 // indirect github.com/mattn/go-isatty v0.0.20 // indirect
github.com/remyoudompheng/bigfft v0.0.0-20230129092748-24d4a6f8daec // indirect github.com/remyoudompheng/bigfft v0.0.0-20230129092748-24d4a6f8daec // indirect

11
go.sum
View File

@@ -1,3 +1,7 @@
github.com/boombuler/barcode v1.0.1-0.20190219062509-6c824513bacc h1:biVzkmvwrH8WK8raXaxBx6fRVTlJILwEwQGL1I/ByEI=
github.com/boombuler/barcode v1.0.1-0.20190219062509-6c824513bacc/go.mod h1:paBWMcWSl3LHKBqUq+rly7CNSldXjb2rDl3JlRe0mD8=
github.com/davecgh/go-spew v1.1.0 h1:ZDRjVQ15GmhC3fiQ8ni8+OwkZQO4DARzQgrnXU1Liz8=
github.com/davecgh/go-spew v1.1.0/go.mod h1:J7Y8YcW2NihsgmVo/mv3lAwl/skON4iLHjSsI+c5H38=
github.com/dustin/go-humanize v1.0.1 h1:GzkhY7T5VNhEkwH0PVJgjz+fX1rhBrR7pRT3mDkpeCY= github.com/dustin/go-humanize v1.0.1 h1:GzkhY7T5VNhEkwH0PVJgjz+fX1rhBrR7pRT3mDkpeCY=
github.com/dustin/go-humanize v1.0.1/go.mod h1:Mu1zIs6XwVuF/gI1OepvI0qD18qycQx+mFykh5fBlto= github.com/dustin/go-humanize v1.0.1/go.mod h1:Mu1zIs6XwVuF/gI1OepvI0qD18qycQx+mFykh5fBlto=
github.com/glebarez/go-sqlite v1.22.0 h1:uAcMJhaA6r3LHMTFgP0SifzgXg46yJkgxqyuyec+ruQ= github.com/glebarez/go-sqlite v1.22.0 h1:uAcMJhaA6r3LHMTFgP0SifzgXg46yJkgxqyuyec+ruQ=
@@ -10,8 +14,15 @@ github.com/google/uuid v1.5.0 h1:1p67kYwdtXjb0gL0BPiP1Av9wiZPo5A8z2cWkTZ+eyU=
github.com/google/uuid v1.5.0/go.mod h1:TIyPZe4MgqvfeYDBFedMoGGpEw/LqOeaOT+nhxU+yHo= github.com/google/uuid v1.5.0/go.mod h1:TIyPZe4MgqvfeYDBFedMoGGpEw/LqOeaOT+nhxU+yHo=
github.com/mattn/go-isatty v0.0.20 h1:xfD0iDuEKnDkl03q4limB+vH+GxLEtL/jb4xVJSWWEY= github.com/mattn/go-isatty v0.0.20 h1:xfD0iDuEKnDkl03q4limB+vH+GxLEtL/jb4xVJSWWEY=
github.com/mattn/go-isatty v0.0.20/go.mod h1:W+V8PltTTMOvKvAeJH7IuucS94S2C6jfK/D7dTCTo3Y= github.com/mattn/go-isatty v0.0.20/go.mod h1:W+V8PltTTMOvKvAeJH7IuucS94S2C6jfK/D7dTCTo3Y=
github.com/pmezard/go-difflib v1.0.0 h1:4DBwDE0NGyQoBHbLQYPwSUPoCMWR5BEzIk/f1lZbAQM=
github.com/pmezard/go-difflib v1.0.0/go.mod h1:iKH77koFhYxTK1pcRnkKkqfTogsbg7gZNVY4sRDYZ/4=
github.com/pquerna/otp v1.5.0 h1:NMMR+WrmaqXU4EzdGJEE1aUUI0AMRzsp96fFFWNPwxs=
github.com/pquerna/otp v1.5.0/go.mod h1:dkJfzwRKNiegxyNb54X/3fLwhCynbMspSyWKnvi1AEg=
github.com/remyoudompheng/bigfft v0.0.0-20230129092748-24d4a6f8daec h1:W09IVJc94icq4NjY3clb7Lk8O1qJ8BdBEF8z0ibU0rE= github.com/remyoudompheng/bigfft v0.0.0-20230129092748-24d4a6f8daec h1:W09IVJc94icq4NjY3clb7Lk8O1qJ8BdBEF8z0ibU0rE=
github.com/remyoudompheng/bigfft v0.0.0-20230129092748-24d4a6f8daec/go.mod h1:qqbHyh8v60DhA7CoWK5oRCqLrMHRGoxYCSS9EjAz6Eo= github.com/remyoudompheng/bigfft v0.0.0-20230129092748-24d4a6f8daec/go.mod h1:qqbHyh8v60DhA7CoWK5oRCqLrMHRGoxYCSS9EjAz6Eo=
github.com/stretchr/objx v0.1.0/go.mod h1:HFkY916IF+rwdDfMAkV7OtwuqBVzrE8GR6GFx+wExME=
github.com/stretchr/testify v1.3.0 h1:TivCn/peBQ7UY8ooIcPgZFpTNSz0Q2U6UrFlUfqbe0Q=
github.com/stretchr/testify v1.3.0/go.mod h1:M5WIy9Dh21IEIfnGCwXGc5bZfKNJtfHm1UVUgZn+9EI=
github.com/tdewolff/minify/v2 v2.24.13 h1:xrcF7gKDnUszseEY9WX9mUlZII2v2Go/QAcAwRASw58= github.com/tdewolff/minify/v2 v2.24.13 h1:xrcF7gKDnUszseEY9WX9mUlZII2v2Go/QAcAwRASw58=
github.com/tdewolff/minify/v2 v2.24.13/go.mod h1:emvwoYeIl8bfAKqRU5ww95LX9Gpggpqv/naal9a8Yq0= github.com/tdewolff/minify/v2 v2.24.13/go.mod h1:emvwoYeIl8bfAKqRU5ww95LX9Gpggpqv/naal9a8Yq0=
github.com/tdewolff/parse/v2 v2.8.12 h1:5BBjfaCv482v3nltlS0u6wH1xJaxjR6ofDrWttNvROg= github.com/tdewolff/parse/v2 v2.8.12 h1:5BBjfaCv482v3nltlS0u6wH1xJaxjR6ofDrWttNvROg=

View File

@@ -2,19 +2,23 @@ package handlers
import ( import (
"MiauInv/auth" "MiauInv/auth"
"MiauInv/config"
"MiauInv/models" "MiauInv/models"
"MiauInv/storage" "MiauInv/storage"
"MiauInv/util" utils "MiauInv/util"
"bytes"
"crypto/rand"
"encoding/base64"
"encoding/hex"
"encoding/json" "encoding/json"
"image/png"
"log" "log"
"net/http" "net/http"
"os" "os"
"strings" "strings"
"time" "time"
)
var cfg, _ = config.LoadConfig() "github.com/pquerna/otp/totp"
)
func APIRegister(w http.ResponseWriter, r *http.Request) { func APIRegister(w http.ResponseWriter, r *http.Request) {
var user models.User var user models.User
@@ -30,6 +34,12 @@ func APIRegister(w http.ResponseWriter, r *http.Request) {
return return
} }
if len(user.Password) > 72 {
log.Println("POST [api/register] User password too long")
http.Error(w, "Password exceeds the maximum allowed length of 72 characters", http.StatusUnprocessableEntity)
return
}
hashed, err := auth.HashPassword(user.Password) hashed, err := auth.HashPassword(user.Password)
if err != nil { if err != nil {
log.Println("POST [api/register] " + r.RemoteAddr + ": " + err.Error()) log.Println("POST [api/register] " + r.RemoteAddr + ": " + err.Error())
@@ -49,6 +59,7 @@ func APIRegister(w http.ResponseWriter, r *http.Request) {
w.WriteHeader(http.StatusCreated) w.WriteHeader(http.StatusCreated)
log.Println("POST [api/register] " + r.RemoteAddr + ": Successfully created user") log.Println("POST [api/register] " + r.RemoteAddr + ": Successfully created user")
} }
func APILogin(w http.ResponseWriter, r *http.Request) { func APILogin(w http.ResponseWriter, r *http.Request) {
var creds struct { var creds struct {
Username string `json:"username"` Username string `json:"username"`
@@ -80,76 +91,300 @@ func APILogin(w http.ResponseWriter, r *http.Request) {
return return
} }
accessToken, err := auth.GenerateJWT(user.ID, user.Role, secret) if user.TwoFactorEnabled {
if err != nil { twoFactorToken, err := auth.GeneratePurposeJWT(user.ID, "2fa_login", secret, 5*time.Minute)
log.Println("POST [api/login] " + r.RemoteAddr + ": " + err.Error()) if err != nil {
http.Error(w, "Could not generate token", http.StatusInternalServerError) log.Println("POST [api/login] " + r.RemoteAddr + ": " + err.Error())
http.Error(w, "Could not generate 2FA challenge", http.StatusInternalServerError)
return
}
writeJSON(w, http.StatusOK, map[string]interface{}{
"requires_2fa": true,
"two_factor_token": twoFactorToken,
})
log.Println("POST [api/login] " + r.RemoteAddr + ": Password accepted, 2FA required")
return return
} }
refreshTokenPlain, err := utils.GenerateRefreshToken() issueLoginSession(w, r, user)
if err != nil {
log.Println("POST [api/login] " + r.RemoteAddr + ": " + err.Error())
http.Error(w, "could not generate refresh token", http.StatusInternalServerError)
return
}
refreshHash := utils.HashToken(refreshTokenPlain)
refreshID := utils.GenerateUUID()
refreshExpires := time.Now().Add(7 * 24 * time.Hour).Unix() // expiry: 7 days
deviceInfo := r.Header.Get("User-Agent")
if err := storage.AddRefreshToken(&models.RefreshToken{
ID: refreshID,
UserID: user.ID,
Token: refreshHash,
ExpiresAt: refreshExpires,
DeviceInfo: deviceInfo,
CreatedAt: time.Now().Unix(),
Revoked: false,
}); err != nil {
log.Println("POST [api/login] " + r.RemoteAddr + ": " + err.Error())
http.Error(w, "could not save refresh token", http.StatusInternalServerError)
return
}
// Return access + refresh token (refresh in plain for client to store securely)
resp := map[string]interface{}{
"access_token": accessToken,
"refresh_token": refreshTokenPlain,
"user": map[string]interface{}{
"id": user.ID,
"username": user.Username,
"role": user.Role,
},
}
http.SetCookie(w, &http.Cookie{
Name: "access_token",
Value: accessToken,
Path: "/",
HttpOnly: true,
Secure: true,
SameSite: http.SameSiteLaxMode,
})
http.SetCookie(w, &http.Cookie{
Name: "refresh_token",
Value: refreshTokenPlain,
Path: "/",
HttpOnly: true,
Secure: true,
SameSite: http.SameSiteLaxMode,
})
w.Header().Set("Content-Type", "application/json")
err = json.NewEncoder(w).Encode(resp)
if err != nil {
log.Println("POST [api/login] " + r.RemoteAddr + ": " + err.Error())
http.Error(w, "Something went wrong", http.StatusInternalServerError)
return
}
log.Println("POST [api/login] " + r.RemoteAddr + ": Successfully logged in") log.Println("POST [api/login] " + r.RemoteAddr + ": Successfully logged in")
} }
func APILoginTwoFactor(w http.ResponseWriter, r *http.Request) {
var req struct {
TwoFactorToken string `json:"two_factor_token"`
Code string `json:"code"`
}
if err := json.NewDecoder(r.Body).Decode(&req); err != nil {
log.Println("POST [api/login/2fa] " + r.RemoteAddr + ": " + err.Error())
http.Error(w, "Invalid request", http.StatusBadRequest)
return
}
secret := []byte(os.Getenv("JWT_SECRET"))
claims, err := auth.ValidatePurposeJWT(req.TwoFactorToken, "2fa_login", secret)
if err != nil {
log.Println("POST [api/login/2fa] " + r.RemoteAddr + ": " + err.Error())
http.Error(w, "Invalid or expired 2FA challenge", http.StatusUnauthorized)
return
}
user, err := storage.GetUserById(claims.UserID)
if err != nil || !user.TwoFactorEnabled || user.TwoFactorSecret == "" {
log.Println("POST [api/login/2fa] " + r.RemoteAddr + ": 2FA not available for user")
http.Error(w, "Invalid 2FA state", http.StatusUnauthorized)
return
}
code := strings.TrimSpace(req.Code)
validTOTP := totp.Validate(code, user.TwoFactorSecret)
usedRecoveryCode := false
if !validTOTP {
recoveryCodeHash := utils.HashToken(normalizeRecoveryCode(code))
usedRecoveryCode, err = storage.UseUserRecoveryCode(user.ID, recoveryCodeHash)
if err != nil {
log.Println("POST [api/login/2fa] " + r.RemoteAddr + ": " + err.Error())
http.Error(w, "Could not validate recovery code", http.StatusInternalServerError)
return
}
}
if !validTOTP && !usedRecoveryCode {
log.Println("POST [api/login/2fa] " + r.RemoteAddr + ": Invalid 2FA or recovery code")
http.Error(w, "Invalid 2FA or recovery code", http.StatusUnauthorized)
return
}
issueLoginSession(w, r, user)
if usedRecoveryCode {
log.Println("POST [api/login/2fa] " + r.RemoteAddr + ": Successfully logged in with recovery code")
return
}
log.Println("POST [api/login/2fa] " + r.RemoteAddr + ": Successfully logged in with 2FA")
}
func TwoFactorSetup(w http.ResponseWriter, r *http.Request) {
if r.Method != http.MethodPost {
http.Error(w, "Method not allowed", http.StatusMethodNotAllowed)
return
}
claims := r.Context().Value(auth.UserContextKey).(*auth.Claims)
user, err := storage.GetUserById(claims.UserID)
if err != nil {
log.Println("POST [api/2fa/setup] " + r.RemoteAddr + ": " + err.Error())
http.Error(w, "User not found", http.StatusNotFound)
return
}
if user.TwoFactorEnabled {
http.Error(w, "2FA is already enabled", http.StatusConflict)
return
}
key, err := totp.Generate(totp.GenerateOpts{
Issuer: "MiauInv",
AccountName: user.Username,
SecretSize: 20,
})
if err != nil {
log.Println("POST [api/2fa/setup] " + r.RemoteAddr + ": " + err.Error())
http.Error(w, "Could not generate 2FA secret", http.StatusInternalServerError)
return
}
if err := storage.SetUserTwoFactorSecret(user.ID, key.Secret()); err != nil {
log.Println("POST [api/2fa/setup] " + r.RemoteAddr + ": " + err.Error())
http.Error(w, "Could not save 2FA secret", http.StatusInternalServerError)
return
}
img, err := key.Image(220, 220)
if err != nil {
log.Println("POST [api/2fa/setup] " + r.RemoteAddr + ": " + err.Error())
http.Error(w, "Could not generate QR code", http.StatusInternalServerError)
return
}
var qr bytes.Buffer
if err := png.Encode(&qr, img); err != nil {
log.Println("POST [api/2fa/setup] " + r.RemoteAddr + ": " + err.Error())
http.Error(w, "Could not encode QR code", http.StatusInternalServerError)
return
}
writeJSON(w, http.StatusOK, map[string]interface{}{
"secret": key.Secret(),
"otpauth_url": key.URL(),
"qr_code": "data:image/png;base64," + base64.StdEncoding.EncodeToString(qr.Bytes()),
})
log.Println("POST [api/2fa/setup] " + r.RemoteAddr + ": Created 2FA setup challenge")
}
func TwoFactorEnable(w http.ResponseWriter, r *http.Request) {
if r.Method != http.MethodPost {
http.Error(w, "Method not allowed", http.StatusMethodNotAllowed)
return
}
var req struct {
Code string `json:"code"`
}
if err := json.NewDecoder(r.Body).Decode(&req); err != nil {
log.Println("POST [api/2fa/enable] " + r.RemoteAddr + ": " + err.Error())
http.Error(w, "Invalid request", http.StatusBadRequest)
return
}
claims := r.Context().Value(auth.UserContextKey).(*auth.Claims)
user, err := storage.GetUserById(claims.UserID)
if err != nil {
log.Println("POST [api/2fa/enable] " + r.RemoteAddr + ": " + err.Error())
http.Error(w, "User not found", http.StatusNotFound)
return
}
if user.TwoFactorSecret == "" {
http.Error(w, "2FA setup has not been started", http.StatusBadRequest)
return
}
if !totp.Validate(strings.TrimSpace(req.Code), user.TwoFactorSecret) {
http.Error(w, "Invalid 2FA code", http.StatusUnauthorized)
return
}
recoveryCodes, recoveryCodeHashes, err := generateRecoveryCodes(10)
if err != nil {
log.Println("POST [api/2fa/enable] " + r.RemoteAddr + ": " + err.Error())
http.Error(w, "Could not generate recovery codes", http.StatusInternalServerError)
return
}
if err := storage.EnableUserTwoFactorWithRecoveryCodes(user.ID, recoveryCodeHashes); err != nil {
log.Println("POST [api/2fa/enable] " + r.RemoteAddr + ": " + err.Error())
http.Error(w, "Could not enable 2FA", http.StatusInternalServerError)
return
}
writeJSON(w, http.StatusOK, map[string]interface{}{
"two_factor_enabled": true,
"recovery_codes": recoveryCodes,
})
log.Println("POST [api/2fa/enable] " + r.RemoteAddr + ": Enabled 2FA and generated recovery codes")
}
func TwoFactorDisable(w http.ResponseWriter, r *http.Request) {
if r.Method != http.MethodPost {
http.Error(w, "Method not allowed", http.StatusMethodNotAllowed)
return
}
var req struct {
Password string `json:"password"`
Code string `json:"code"`
}
if err := json.NewDecoder(r.Body).Decode(&req); err != nil {
log.Println("POST [api/2fa/disable] " + r.RemoteAddr + ": " + err.Error())
http.Error(w, "Invalid request", http.StatusBadRequest)
return
}
claims := r.Context().Value(auth.UserContextKey).(*auth.Claims)
user, err := storage.GetUserById(claims.UserID)
if err != nil {
log.Println("POST [api/2fa/disable] " + r.RemoteAddr + ": " + err.Error())
http.Error(w, "User not found", http.StatusNotFound)
return
}
if !auth.CheckPasswordHash(req.Password, user.Password) {
http.Error(w, "Invalid password", http.StatusUnauthorized)
return
}
if user.TwoFactorEnabled && !totp.Validate(strings.TrimSpace(req.Code), user.TwoFactorSecret) {
http.Error(w, "Invalid 2FA code", http.StatusUnauthorized)
return
}
if err := storage.DisableUserTwoFactor(user.ID); err != nil {
log.Println("POST [api/2fa/disable] " + r.RemoteAddr + ": " + err.Error())
http.Error(w, "Could not disable 2FA", http.StatusInternalServerError)
return
}
if err := storage.RevokeAllRefreshTokensForUser(user.ID); err != nil {
log.Println("POST [api/2fa/disable] " + r.RemoteAddr + ": " + err.Error())
http.Error(w, "Could not revoke sessions", http.StatusInternalServerError)
return
}
clearAuthCookies(w)
writeJSON(w, http.StatusOK, map[string]interface{}{"two_factor_enabled": false})
log.Println("POST [api/2fa/disable] " + r.RemoteAddr + ": Disabled 2FA")
}
func TwoFactorRegenerateRecoveryCodes(w http.ResponseWriter, r *http.Request) {
if r.Method != http.MethodPost {
http.Error(w, "Method not allowed", http.StatusMethodNotAllowed)
return
}
var req struct {
Password string `json:"password"`
Code string `json:"code"`
}
if err := json.NewDecoder(r.Body).Decode(&req); err != nil {
log.Println("POST [api/2fa/recovery-codes/regenerate] " + r.RemoteAddr + ": " + err.Error())
http.Error(w, "Invalid request", http.StatusBadRequest)
return
}
claims := r.Context().Value(auth.UserContextKey).(*auth.Claims)
user, err := storage.GetUserById(claims.UserID)
if err != nil {
log.Println("POST [api/2fa/recovery-codes/regenerate] " + r.RemoteAddr + ": " + err.Error())
http.Error(w, "User not found", http.StatusNotFound)
return
}
if !user.TwoFactorEnabled || user.TwoFactorSecret == "" {
http.Error(w, "2FA is not enabled", http.StatusBadRequest)
return
}
if !auth.CheckPasswordHash(req.Password, user.Password) {
http.Error(w, "Invalid password", http.StatusUnauthorized)
return
}
if !totp.Validate(strings.TrimSpace(req.Code), user.TwoFactorSecret) {
http.Error(w, "Invalid 2FA code", http.StatusUnauthorized)
return
}
recoveryCodes, recoveryCodeHashes, err := generateRecoveryCodes(10)
if err != nil {
log.Println("POST [api/2fa/recovery-codes/regenerate] " + r.RemoteAddr + ": " + err.Error())
http.Error(w, "Could not generate recovery codes", http.StatusInternalServerError)
return
}
if err := storage.ReplaceUserRecoveryCodes(user.ID, recoveryCodeHashes); err != nil {
log.Println("POST [api/2fa/recovery-codes/regenerate] " + r.RemoteAddr + ": " + err.Error())
http.Error(w, "Could not save recovery codes", http.StatusInternalServerError)
return
}
writeJSON(w, http.StatusOK, map[string]interface{}{
"recovery_codes": recoveryCodes,
})
log.Println("POST [api/2fa/recovery-codes/regenerate] " + r.RemoteAddr + ": Regenerated recovery codes")
}
func Logout(w http.ResponseWriter, r *http.Request) { func Logout(w http.ResponseWriter, r *http.Request) {
claims := r.Context().Value(auth.UserContextKey).(*auth.Claims) claims := r.Context().Value(auth.UserContextKey).(*auth.Claims)
err := storage.RevokeAllRefreshTokensForUser(claims.UserID) err := storage.RevokeAllRefreshTokensForUser(claims.UserID)
@@ -158,8 +393,10 @@ func Logout(w http.ResponseWriter, r *http.Request) {
http.Error(w, "Internal server error", http.StatusInternalServerError) http.Error(w, "Internal server error", http.StatusInternalServerError)
return return
} }
w.WriteHeader(204) clearAuthCookies(w)
w.WriteHeader(http.StatusNoContent)
} }
func TestHandler(w http.ResponseWriter, r *http.Request) { func TestHandler(w http.ResponseWriter, r *http.Request) {
claims, _ := utils.IsLoggedIn(w, r) claims, _ := utils.IsLoggedIn(w, r)
@@ -175,13 +412,24 @@ func TestHandler(w http.ResponseWriter, r *http.Request) {
} }
log.Println("GET [api/ping] " + r.RemoteAddr + ": Successfully tested connection") log.Println("GET [api/ping] " + r.RemoteAddr + ": Successfully tested connection")
} }
func RefreshToken(w http.ResponseWriter, r *http.Request) { func RefreshToken(w http.ResponseWriter, r *http.Request) {
var req struct { var req struct {
RefreshToken string `json:"refresh_token"` RefreshToken string `json:"refresh_token"`
} }
if err := json.NewDecoder(r.Body).Decode(&req); err != nil {
log.Println("POST [api/refresh] " + r.RemoteAddr + ": " + err.Error()) if r.Body != nil {
http.Error(w, "Invalid request", http.StatusBadRequest) _ = json.NewDecoder(r.Body).Decode(&req)
}
if req.RefreshToken == "" {
cookie, err := r.Cookie("refresh_token")
if err == nil {
req.RefreshToken = cookie.Value
}
}
if req.RefreshToken == "" {
log.Println("POST [api/refresh] " + r.RemoteAddr + ": Missing refresh token")
http.Error(w, "Invalid refresh token", http.StatusUnauthorized)
return return
} }
@@ -198,43 +446,17 @@ func RefreshToken(w http.ResponseWriter, r *http.Request) {
log.Println(err) log.Println(err)
} }
newToken, _ := utils.GenerateRefreshToken()
newHash := utils.HashToken(newToken)
newExpires := time.Now().Add(7 * 24 * time.Hour).Unix() //7 days
newID := utils.GenerateUUID()
deviceInfo := r.Header.Get("User-Agent")
if err = storage.AddRefreshToken(&models.RefreshToken{
ID: newID,
UserID: tokenRow.UserID,
Token: newHash,
ExpiresAt: newExpires,
CreatedAt: time.Now().Unix(),
Revoked: false,
DeviceInfo: deviceInfo,
}); err != nil {
log.Println("POST [api/refresh] " + r.RemoteAddr + ": " + err.Error())
http.Error(w, "Could not generate new refresh token", http.StatusInternalServerError)
return
}
user, err := storage.GetUserById(tokenRow.UserID) user, err := storage.GetUserById(tokenRow.UserID)
if err != nil { if err != nil {
log.Println("POST [api/refresh] " + r.RemoteAddr + ": " + err.Error()) log.Println("POST [api/refresh] " + r.RemoteAddr + ": " + err.Error())
http.Error(w, "Internal server error", http.StatusInternalServerError) http.Error(w, "Internal server error", http.StatusInternalServerError)
return return
} }
accessToken, _ := auth.GenerateJWT(tokenRow.UserID, user.Role, []byte(os.Getenv("JWT_SECRET")))
if err = json.NewEncoder(w).Encode(map[string]string{ issueLoginSession(w, r, user)
"access_token": accessToken,
"refresh_token": newToken,
}); err != nil {
log.Println("POST [api/refresh] " + r.RemoteAddr + ": " + err.Error())
http.Error(w, "Internal server error", http.StatusInternalServerError)
return
}
log.Println("POST [api/refresh] " + r.RemoteAddr + ": Successfully refreshed token") log.Println("POST [api/refresh] " + r.RemoteAddr + ": Successfully refreshed token")
} }
func UserInfo(w http.ResponseWriter, r *http.Request) { func UserInfo(w http.ResponseWriter, r *http.Request) {
if r.Method != http.MethodGet { if r.Method != http.MethodGet {
log.Println("GET [api/userinfo] " + r.RemoteAddr + ": Method " + r.Method + " not allowed") log.Println("GET [api/userinfo] " + r.RemoteAddr + ": Method " + r.Method + " not allowed")
@@ -281,11 +503,21 @@ func UserInfo(w http.ResponseWriter, r *http.Request) {
http.Error(w, "User not found", http.StatusNotFound) http.Error(w, "User not found", http.StatusNotFound)
return return
} }
recoveryCodesRemaining := 0
if user.TwoFactorEnabled {
if count, err := storage.CountUnusedRecoveryCodes(user.ID); err == nil {
recoveryCodesRemaining = count
}
}
w.Header().Set("Content-Type", "application/json") w.Header().Set("Content-Type", "application/json")
err = json.NewEncoder(w).Encode(map[string]interface{}{ err = json.NewEncoder(w).Encode(map[string]interface{}{
"id": user.ID, "id": user.ID,
"username": user.Username, "username": user.Username,
"avatar_url": "", "avatar_url": "",
"two_factor_enabled": user.TwoFactorEnabled,
"recovery_codes_remaining": recoveryCodesRemaining,
}) })
if err != nil { if err != nil {
log.Println("GET [api/userinfo] " + r.RemoteAddr + ": " + err.Error()) log.Println("GET [api/userinfo] " + r.RemoteAddr + ": " + err.Error())
@@ -293,3 +525,137 @@ func UserInfo(w http.ResponseWriter, r *http.Request) {
} }
log.Println("GET [api/userinfo] " + r.RemoteAddr + ": Successfully retrieved user info of " + user.Username + " (" + user.ID + ")") log.Println("GET [api/userinfo] " + r.RemoteAddr + ": Successfully retrieved user info of " + user.Username + " (" + user.ID + ")")
} }
func issueLoginSession(w http.ResponseWriter, r *http.Request, user models.User) {
secret := []byte(os.Getenv("JWT_SECRET"))
if len(secret) == 0 {
log.Println("AUTH " + r.RemoteAddr + ": Server misconfiguration")
http.Error(w, "Server misconfiguration", http.StatusInternalServerError)
return
}
accessToken, err := auth.GenerateJWT(user.ID, user.Role, secret)
if err != nil {
log.Println("AUTH " + r.RemoteAddr + ": " + err.Error())
http.Error(w, "Could not generate token", http.StatusInternalServerError)
return
}
refreshTokenPlain, err := utils.GenerateRefreshToken()
if err != nil {
log.Println("AUTH " + r.RemoteAddr + ": " + err.Error())
http.Error(w, "could not generate refresh token", http.StatusInternalServerError)
return
}
refreshExpires := time.Now().Add(7 * 24 * time.Hour).Unix()
if err := storage.AddRefreshToken(&models.RefreshToken{
ID: utils.GenerateUUID(),
UserID: user.ID,
Token: utils.HashToken(refreshTokenPlain),
ExpiresAt: refreshExpires,
DeviceInfo: r.Header.Get("User-Agent"),
CreatedAt: time.Now().Unix(),
Revoked: false,
}); err != nil {
log.Println("AUTH " + r.RemoteAddr + ": " + err.Error())
http.Error(w, "could not save refresh token", http.StatusInternalServerError)
return
}
setAuthCookies(w, accessToken, refreshTokenPlain)
writeJSON(w, http.StatusOK, map[string]interface{}{
"access_token": accessToken,
"refresh_token": refreshTokenPlain,
"user": map[string]interface{}{
"id": user.ID,
"username": user.Username,
"role": user.Role,
"two_factor_enabled": user.TwoFactorEnabled,
},
})
}
func generateRecoveryCodes(count int) ([]string, []string, error) {
codes := make([]string, 0, count)
hashes := make([]string, 0, count)
seen := make(map[string]struct{}, count)
for len(codes) < count {
code, err := generateRecoveryCode()
if err != nil {
return nil, nil, err
}
normalized := normalizeRecoveryCode(code)
if _, exists := seen[normalized]; exists {
continue
}
seen[normalized] = struct{}{}
codes = append(codes, code)
hashes = append(hashes, utils.HashToken(normalized))
}
return codes, hashes, nil
}
func generateRecoveryCode() (string, error) {
bytes := make([]byte, 10)
if _, err := rand.Read(bytes); err != nil {
return "", err
}
raw := hex.EncodeToString(bytes)
return raw[0:5] + "-" + raw[5:10] + "-" + raw[10:15] + "-" + raw[15:20], nil
}
func normalizeRecoveryCode(code string) string {
code = strings.TrimSpace(code)
code = strings.ReplaceAll(code, "-", "")
code = strings.ReplaceAll(code, " ", "")
return strings.ToLower(code)
}
func setAuthCookies(w http.ResponseWriter, accessToken, refreshToken string) {
http.SetCookie(w, &http.Cookie{
Name: "access_token",
Value: accessToken,
Path: "/",
MaxAge: 15 * 60,
HttpOnly: true,
Secure: true,
SameSite: http.SameSiteLaxMode,
})
http.SetCookie(w, &http.Cookie{
Name: "refresh_token",
Value: refreshToken,
Path: "/",
MaxAge: 7 * 24 * 60 * 60,
HttpOnly: true,
Secure: true,
SameSite: http.SameSiteLaxMode,
})
}
func clearAuthCookies(w http.ResponseWriter) {
for _, name := range []string{"access_token", "refresh_token"} {
http.SetCookie(w, &http.Cookie{
Name: name,
Value: "",
Path: "/",
MaxAge: -1,
HttpOnly: true,
Secure: true,
SameSite: http.SameSiteLaxMode,
})
}
}
func writeJSON(w http.ResponseWriter, status int, data interface{}) {
w.Header().Set("Content-Type", "application/json")
w.WriteHeader(status)
if err := json.NewEncoder(w).Encode(data); err != nil {
log.Println("JSON response error: " + err.Error())
}
}

View File

@@ -1,8 +1,10 @@
package models package models
type User struct { type User struct {
ID string `json:"id"` ID string `json:"id"`
Username string `json:"username"` Username string `json:"username"`
Password string `json:"password"` Password string `json:"password"`
Role string `json:"role"` Role string `json:"role"`
TwoFactorEnabled bool `json:"two_factor_enabled"`
TwoFactorSecret string `json:"-"`
} }

View File

@@ -82,10 +82,15 @@ func (this *Server) Run() {
// API // API
// //
mux.HandleFunc("/api/login", handlers.APILogin) mux.HandleFunc("/api/login", handlers.APILogin)
mux.HandleFunc("/api/login/2fa", handlers.APILoginTwoFactor)
mux.HandleFunc("/api/refresh", handlers.RefreshToken) mux.HandleFunc("/api/refresh", handlers.RefreshToken)
mux.Handle("/api/logout", auth.AuthMiddleware(this.JWTSecret)(http.HandlerFunc(handlers.Logout))) mux.Handle("/api/logout", auth.AuthMiddleware(this.JWTSecret)(http.HandlerFunc(handlers.Logout)))
mux.Handle("/api/profile", auth.AuthMiddleware(this.JWTSecret)(http.HandlerFunc(handlers.UserInfo))) mux.Handle("/api/profile", auth.AuthMiddleware(this.JWTSecret)(http.HandlerFunc(handlers.UserInfo)))
mux.HandleFunc("/api/userinfo", handlers.UserInfo) mux.Handle("/api/2fa/setup", auth.AuthMiddleware(this.JWTSecret)(http.HandlerFunc(handlers.TwoFactorSetup)))
mux.Handle("/api/2fa/enable", auth.AuthMiddleware(this.JWTSecret)(http.HandlerFunc(handlers.TwoFactorEnable)))
mux.Handle("/api/2fa/disable", auth.AuthMiddleware(this.JWTSecret)(http.HandlerFunc(handlers.TwoFactorDisable)))
mux.Handle("/api/2fa/recovery-codes/regenerate", auth.AuthMiddleware(this.JWTSecret)(http.HandlerFunc(handlers.TwoFactorRegenerateRecoveryCodes)))
mux.Handle("/api/userinfo", auth.AuthMiddleware(this.JWTSecret)(http.HandlerFunc(handlers.UserInfo)))
if this.AllowRegistration { if this.AllowRegistration {
mux.HandleFunc("/api/register", handlers.APIRegister) mux.HandleFunc("/api/register", handlers.APIRegister)
} }

View File

@@ -2,10 +2,12 @@ package storage
import ( import (
"MiauInv/models" "MiauInv/models"
utils "MiauInv/util"
"database/sql" "database/sql"
"errors" "errors"
"log" "log"
"strings" "strings"
"time"
_ "github.com/glebarez/go-sqlite" _ "github.com/glebarez/go-sqlite"
) )
@@ -27,7 +29,9 @@ func InitDB(filepath string) error {
id TEXT PRIMARY KEY, id TEXT PRIMARY KEY,
username TEXT NOT NULL UNIQUE, username TEXT NOT NULL UNIQUE,
password TEXT NOT NULL, password TEXT NOT NULL,
role TEXT NOT NULL role TEXT NOT NULL,
two_factor_enabled INTEGER NOT NULL DEFAULT 0,
two_factor_secret TEXT NOT NULL DEFAULT ''
); );
CREATE TABLE IF NOT EXISTS refresh_tokens ( CREATE TABLE IF NOT EXISTS refresh_tokens (
@@ -41,6 +45,16 @@ func InitDB(filepath string) error {
FOREIGN KEY(user_id) REFERENCES users(id) FOREIGN KEY(user_id) REFERENCES users(id)
); );
CREATE TABLE IF NOT EXISTS two_factor_recovery_codes (
id TEXT PRIMARY KEY,
user_id TEXT NOT NULL,
code_hash TEXT NOT NULL,
created_at INTEGER NOT NULL,
used_at INTEGER DEFAULT NULL,
FOREIGN KEY(user_id) REFERENCES users(id) ON DELETE CASCADE,
UNIQUE(user_id, code_hash)
);
CREATE TABLE IF NOT EXISTS items ( CREATE TABLE IF NOT EXISTS items (
id INTEGER PRIMARY KEY AUTOINCREMENT, id INTEGER PRIMARY KEY AUTOINCREMENT,
name TEXT NOT NULL, name TEXT NOT NULL,
@@ -84,7 +98,26 @@ func InitDB(filepath string) error {
log.Fatal(err) log.Fatal(err)
} }
return err if err := ensureUserTwoFactorColumns(); err != nil {
return err
}
return nil
}
func ensureUserTwoFactorColumns() error {
migrations := []string{
"ALTER TABLE users ADD COLUMN two_factor_enabled INTEGER NOT NULL DEFAULT 0",
"ALTER TABLE users ADD COLUMN two_factor_secret TEXT NOT NULL DEFAULT ''",
}
for _, migration := range migrations {
_, err := DB.Exec(migration)
if err != nil && !strings.Contains(strings.ToLower(err.Error()), "duplicate column") {
return err
}
}
return nil
} }
// Users // Users
@@ -93,18 +126,131 @@ func AddUser(user *models.User) error {
return err return err
} }
func GetUserByUsername(username string) (models.User, error) { func GetUserByUsername(username string) (models.User, error) {
row := DB.QueryRow("SELECT * FROM users WHERE username = ?", strings.ToLower(username)) row := DB.QueryRow(`
var user models.User SELECT id, username, password, role, two_factor_enabled, two_factor_secret
err := row.Scan(&user.ID, &user.Username, &user.Password, &user.Role) FROM users
return user, err WHERE username = ?
`, strings.ToLower(username))
return scanUser(row)
} }
func GetUserById(id string) (models.User, error) { func GetUserById(id string) (models.User, error) {
row := DB.QueryRow("SELECT * FROM users WHERE id = ?", id) row := DB.QueryRow(`
SELECT id, username, password, role, two_factor_enabled, two_factor_secret
FROM users
WHERE id = ?
`, id)
return scanUser(row)
}
func scanUser(row *sql.Row) (models.User, error) {
var user models.User var user models.User
err := row.Scan(&user.ID, &user.Username, &user.Password, &user.Role) var twoFactorEnabled int
err := row.Scan(&user.ID, &user.Username, &user.Password, &user.Role, &twoFactorEnabled, &user.TwoFactorSecret)
user.TwoFactorEnabled = twoFactorEnabled == 1
return user, err return user, err
} }
func SetUserTwoFactorSecret(userID, secret string) error {
_, err := DB.Exec("UPDATE users SET two_factor_secret = ? WHERE id = ?", secret, userID)
return err
}
func EnableUserTwoFactorWithRecoveryCodes(userID string, recoveryCodeHashes []string) error {
tx, err := DB.Begin()
if err != nil {
return err
}
defer tx.Rollback()
if _, err := tx.Exec("DELETE FROM two_factor_recovery_codes WHERE user_id = ?", userID); err != nil {
return err
}
now := time.Now().Unix()
for _, codeHash := range recoveryCodeHashes {
if _, err := tx.Exec(`
INSERT INTO two_factor_recovery_codes(id, user_id, code_hash, created_at)
VALUES (?, ?, ?, ?)
`, utils.GenerateUUID(), userID, codeHash, now); err != nil {
return err
}
}
if _, err := tx.Exec("UPDATE users SET two_factor_enabled = 1 WHERE id = ?", userID); err != nil {
return err
}
return tx.Commit()
}
func DisableUserTwoFactor(userID string) error {
tx, err := DB.Begin()
if err != nil {
return err
}
defer tx.Rollback()
if _, err := tx.Exec("UPDATE users SET two_factor_enabled = 0, two_factor_secret = '' WHERE id = ?", userID); err != nil {
return err
}
if _, err := tx.Exec("DELETE FROM two_factor_recovery_codes WHERE user_id = ?", userID); err != nil {
return err
}
return tx.Commit()
}
func ReplaceUserRecoveryCodes(userID string, recoveryCodeHashes []string) error {
tx, err := DB.Begin()
if err != nil {
return err
}
defer tx.Rollback()
if _, err := tx.Exec("DELETE FROM two_factor_recovery_codes WHERE user_id = ?", userID); err != nil {
return err
}
now := time.Now().Unix()
for _, codeHash := range recoveryCodeHashes {
if _, err := tx.Exec(`
INSERT INTO two_factor_recovery_codes(id, user_id, code_hash, created_at)
VALUES (?, ?, ?, ?)
`, utils.GenerateUUID(), userID, codeHash, now); err != nil {
return err
}
}
return tx.Commit()
}
func UseUserRecoveryCode(userID, codeHash string) (bool, error) {
res, err := DB.Exec(`
UPDATE two_factor_recovery_codes
SET used_at = ?
WHERE user_id = ? AND code_hash = ? AND used_at IS NULL
`, time.Now().Unix(), userID, codeHash)
if err != nil {
return false, err
}
n, err := res.RowsAffected()
if err != nil {
return false, err
}
return n == 1, nil
}
func CountUnusedRecoveryCodes(userID string) (int, error) {
var count int
err := DB.QueryRow(`
SELECT COUNT(*)
FROM two_factor_recovery_codes
WHERE user_id = ? AND used_at IS NULL
`, userID).Scan(&count)
return count, err
}
// Refresh Tokens // Refresh Tokens
func AddRefreshToken(token *models.RefreshToken) error { func AddRefreshToken(token *models.RefreshToken) error {
_, err := DB.Exec("INSERT INTO refresh_tokens(id, user_id, token_hash, expires_at, created_at, revoked, device_info) VALUES (?, ?, ?, ?, ?, ?, ?)", _, err := DB.Exec("INSERT INTO refresh_tokens(id, user_id, token_hash, expires_at, created_at, revoked, device_info) VALUES (?, ?, ?, ?, ?, ?, ?)",