diff --git a/handlers/account.go b/handlers/account.go index 9f9aa83..89323e3 100644 --- a/handlers/account.go +++ b/handlers/account.go @@ -3,6 +3,7 @@ package handlers import ( "encoding/json" "net/http" + "os" "shap-planner-backend/auth" "shap-planner-backend/models" "shap-planner-backend/storage" @@ -11,16 +12,30 @@ import ( func Register(w http.ResponseWriter, r *http.Request) { var user models.User - _ = json.NewDecoder(r.Body).Decode(&user) - hashed, _ := auth.HashPassword(user.Password) - user.Password = hashed - user.ID = utils.GenerateUUID() - - err := storage.AddUser(user) - if err != nil { - http.Error(w, "User exists", http.StatusBadRequest) + if err := json.NewDecoder(r.Body).Decode(&user); err != nil { + http.Error(w, "Invalid request body", http.StatusBadRequest) return } + + if user.Username == "" || user.Password == "" { + http.Error(w, "username and password required", http.StatusBadRequest) + return + } + + hashed, err := auth.HashPassword(user.Password) + if err != nil { + http.Error(w, "internal error", http.StatusInternalServerError) + return + } + user.Password = hashed + user.ID = utils.GenerateUUID() + user.Role = "user" + + if err := storage.AddUser(user); err != nil { + http.Error(w, "user already exists", http.StatusBadRequest) + return + } + w.WriteHeader(http.StatusCreated) } @@ -29,23 +44,68 @@ func Login(w http.ResponseWriter, r *http.Request) { Username string `json:"username"` Password string `json:"password"` } - _ = json.NewDecoder(r.Body).Decode(&creds) + if err := json.NewDecoder(r.Body).Decode(&creds); err != nil { + http.Error(w, "Invalid request", http.StatusBadRequest) + return + } user, err := storage.GetUserByUsername(creds.Username) if err != nil { - http.Error(w, "User not found", http.StatusUnauthorized) + http.Error(w, "Invalid credentials", http.StatusUnauthorized) return } if !auth.CheckPasswordHash(creds.Password, user.Password) { - http.Error(w, "Wrong password", http.StatusUnauthorized) + http.Error(w, "Invalid credentials", http.StatusUnauthorized) return } - // TODO: JWT oder Session-Token erzeugen - w.WriteHeader(http.StatusOK) - err = json.NewEncoder(w).Encode(user) - if err != nil { + secret := []byte(os.Getenv("SHAP_JWT_SECRET")) + if len(secret) == 0 { + http.Error(w, "Server misconfiguration", http.StatusInternalServerError) return } + + token, err := auth.GenerateJWT(user.ID, user.Role, secret) + if err != nil { + http.Error(w, "Could not generate token", http.StatusInternalServerError) + return + } + + type userResp struct { + ID string `json:"id"` + Username string `json:"username"` + Role string `json:"role"` + } + + w.Header().Set("Content-Type", "application/json") + json.NewEncoder(w).Encode(map[string]interface{}{ + "token": token, + "user": userResp{ + ID: user.ID, + Username: user.Username, + Role: user.Role, + }, + }) +} + +func TestHandler(w http.ResponseWriter, r *http.Request) { + claimsRaw := r.Context().Value(auth.UserContextKey) + if claimsRaw == nil { + http.Error(w, "No claims in context", http.StatusUnauthorized) + return + } + + claims, ok := claimsRaw.(*auth.Claims) + if !ok { + http.Error(w, "Invalid claims", http.StatusUnauthorized) + return + } + + w.Header().Set("Content-Type", "application/json") + json.NewEncoder(w).Encode(map[string]interface{}{ + "user_id": claims.UserID, + "role": claims.Role, + "msg": "access granted to protected endpoint", + }) } diff --git a/models/models.go b/models/dbmodels.go similarity index 91% rename from models/models.go rename to models/dbmodels.go index d334943..72c5e30 100644 --- a/models/models.go +++ b/models/dbmodels.go @@ -4,6 +4,7 @@ type User struct { ID string `json:"id"` Username string `json:"username"` Password string `json:"password"` + Role string `json:"role"` } type Expense struct { diff --git a/models/loginmodels.go b/models/loginmodels.go new file mode 100644 index 0000000..8130760 --- /dev/null +++ b/models/loginmodels.go @@ -0,0 +1,10 @@ +package models + +import "time" + +type RefreshToken struct { + ID string `json:id` + UserID string `json:userid` + Token string `json:token` + ExpiresAt time.Time `json:expiresat` +} diff --git a/server/server.go b/server/server.go index d1bff77..553ffe5 100644 --- a/server/server.go +++ b/server/server.go @@ -45,13 +45,16 @@ func InitServer() *Server { func (server *Server) Run() { mux := http.NewServeMux() - mux.HandleFunc("/login", handlers.Login) + // Public + mux.HandleFunc("/api/login", handlers.Login) + mux.HandleFunc("/api/register", handlers.Register) - protected := auth.AuthMiddleware(server.JWTSecret)(http.HandlerFunc(handlers.GetExpenses)) - mux.Handle("/expenses", protected) + // Login required + mux.Handle("/api/expenses", auth.AuthMiddleware(server.JWTSecret)(http.HandlerFunc(handlers.GetExpenses))) + mux.Handle("/api/ping", auth.AuthMiddleware(server.JWTSecret)(http.HandlerFunc(handlers.TestHandler))) - adminOnly := auth.AuthMiddleware(server.JWTSecret)(auth.RequireRole("admin")(http.HandlerFunc(handlers.AdminPanel))) - mux.Handle("/admin", adminOnly) + // Admin-only + mux.Handle("/api/admin", auth.AuthMiddleware(server.JWTSecret)(auth.RequireRole("admin")(http.HandlerFunc(handlers.AdminPanel)))) log.Printf("Listening on port %s", server.Port) log.Fatal(http.ListenAndServe(":"+server.Port, mux)) diff --git a/storage/storage.go b/storage/storage.go index 3fd0454..ab6228e 100644 --- a/storage/storage.go +++ b/storage/storage.go @@ -2,8 +2,9 @@ package storage import ( "database/sql" - _ "github.com/glebarez/go-sqlite" "shap-planner-backend/models" + + _ "github.com/glebarez/go-sqlite" ) var DB *sql.DB @@ -34,7 +35,7 @@ func InitDB(filepath string) error { } func AddUser(user models.User) error { - _, err := DB.Exec("INSERT INTO users(id, username, password) VALUES (?, ?, ?)", user.ID, user.Username, user.Password) + _, err := DB.Exec("INSERT INTO users(id, username, password, role) VALUES (?, ?, ?, ?)", user.ID, user.Username, user.Password, user.Role) return err }