From ce88b0b432736f94f5ad00c37e77d21a54752ee0 Mon Sep 17 00:00:00 2001 From: farhadh Date: Mon, 11 May 2026 21:09:26 +0200 Subject: [PATCH] refactor(session): store user ID in session instead of full struct Replaces storing the full User object in the session cookie with just the user ID. GetLoginUser now re-fetches the user from the database on every request so credential/permission changes take effect immediately without requiring a re-login. Includes a backward-compatible migration path for existing sessions that still carry the old struct payload. --- web/session/session.go | 71 +++++++++++++++++++++++++++++++++++-- web/session/session_test.go | 47 ++++++++++++++++++++++++ 2 files changed, 115 insertions(+), 3 deletions(-) create mode 100644 web/session/session_test.go diff --git a/web/session/session.go b/web/session/session.go index 6ca43e78..b7340922 100644 --- a/web/session/session.go +++ b/web/session/session.go @@ -5,6 +5,7 @@ import ( "net/http" "time" + "github.com/mhsanaei/3x-ui/v3/database" "github.com/mhsanaei/3x-ui/v3/database/model" "github.com/mhsanaei/3x-ui/v3/logger" @@ -27,7 +28,7 @@ func SetLoginUser(c *gin.Context, user *model.User) error { return nil } s := sessions.Default(c) - s.Set(loginUserKey, *user) + s.Set(loginUserKey, user.Id) return s.Save() } @@ -49,7 +50,7 @@ func GetLoginUser(c *gin.Context) *model.User { if obj == nil { return nil } - user, ok := obj.(model.User) + userID, ok := sessionUserID(obj) if !ok { s.Delete(loginUserKey) if err := s.Save(); err != nil { @@ -57,13 +58,77 @@ func GetLoginUser(c *gin.Context) *model.User { } return nil } - return &user + if legacyUserID, ok := legacySessionUserID(obj); ok { + s.Set(loginUserKey, legacyUserID) + if err := s.Save(); err != nil { + logger.Warning("session: failed to migrate legacy user payload:", err) + } + } + user, err := getUserByID(userID) + if err != nil { + logger.Warning("session: failed to load user:", err) + s.Delete(loginUserKey) + if saveErr := s.Save(); saveErr != nil { + logger.Warning("session: failed to drop missing user:", saveErr) + } + return nil + } + return user } func IsLogin(c *gin.Context) bool { return GetLoginUser(c) != nil } +func sessionUserID(obj any) (int, bool) { + switch v := obj.(type) { + case int: + return v, v > 0 + case int64: + return int(v), v > 0 + case int32: + return int(v), v > 0 + case float64: + id := int(v) + return id, v == float64(id) && id > 0 + case model.User: + return v.Id, v.Id > 0 + case *model.User: + if v == nil { + return 0, false + } + return v.Id, v.Id > 0 + default: + return 0, false + } +} + +func legacySessionUserID(obj any) (int, bool) { + switch v := obj.(type) { + case model.User: + return v.Id, v.Id > 0 + case *model.User: + if v == nil { + return 0, false + } + return v.Id, v.Id > 0 + default: + return 0, false + } +} + +func getUserByID(id int) (*model.User, error) { + db := database.GetDB() + if db == nil { + return nil, http.ErrServerClosed + } + user := &model.User{} + if err := db.Model(model.User{}).Where("id = ?", id).First(user).Error; err != nil { + return nil, err + } + return user, nil +} + func ClearSession(c *gin.Context) error { s := sessions.Default(c) s.Clear() diff --git a/web/session/session_test.go b/web/session/session_test.go new file mode 100644 index 00000000..bb48fea2 --- /dev/null +++ b/web/session/session_test.go @@ -0,0 +1,47 @@ +package session + +import ( + "net/http" + "net/http/httptest" + "testing" + + "github.com/mhsanaei/3x-ui/v3/database/model" + + "github.com/gin-contrib/sessions" + "github.com/gin-contrib/sessions/cookie" + "github.com/gin-gonic/gin" +) + +func TestSetLoginUserStoresOnlyUserID(t *testing.T) { + gin.SetMode(gin.TestMode) + router := gin.New() + router.Use(sessions.Sessions(sessionCookieName, cookie.NewStore([]byte("01234567890123456789012345678901")))) + router.GET("/", func(c *gin.Context) { + if err := SetLoginUser(c, &model.User{Id: 7, Username: "admin", Password: "hash"}); err != nil { + t.Fatal(err) + } + got := sessions.Default(c).Get(loginUserKey) + if got != 7 { + t.Fatalf("stored session payload = %#v, want user id only", got) + } + c.Status(http.StatusNoContent) + }) + + rec := httptest.NewRecorder() + req := httptest.NewRequest(http.MethodGet, "/", nil) + router.ServeHTTP(rec, req) + if rec.Code != http.StatusNoContent { + t.Fatalf("status = %d, want %d", rec.Code, http.StatusNoContent) + } +} + +func TestSessionUserIDSupportsLegacyUserPayload(t *testing.T) { + id, ok := sessionUserID(model.User{Id: 11, Username: "admin", Password: "hash"}) + if !ok || id != 11 { + t.Fatalf("legacy session payload resolved to (%d, %v), want (11, true)", id, ok) + } + id, ok = sessionUserID(&model.User{Id: 12, Username: "admin", Password: "hash"}) + if !ok || id != 12 { + t.Fatalf("legacy pointer session payload resolved to (%d, %v), want (12, true)", id, ok) + } +}