From 3c1170bd8702aeacdea04bc61d1da9edc0c2de6e Mon Sep 17 00:00:00 2001 From: Zeni Kim Date: Mon, 21 Oct 2024 20:44:04 -0500 Subject: [PATCH] sessions review --- utils/session.go | 44 ++++++++++++++++++-------------------------- 1 file changed, 18 insertions(+), 26 deletions(-) diff --git a/utils/session.go b/utils/session.go index c792ba5..27e202d 100644 --- a/utils/session.go +++ b/utils/session.go @@ -5,7 +5,6 @@ package utils import ( - "crypto/md5" "errors" "fmt" "sync" @@ -37,7 +36,7 @@ func (s *SessionUser) Init(c *core.Context) bool { s.context = c payload := make(map[string]interface{}) - + // get cookie usercookie, err := c.GetCookie() if err != nil { @@ -54,13 +53,16 @@ func (s *SessionUser) Init(c *core.Context) bool { payload, err = c.GetJWT().DecodeToken(token) if err != nil { + pass = false + } else { userID := uint(c.CastToInt(payload["userID"])) userAgent := c.GetUserAgent() - hashedCacheKey := CreateAuthTokenHashedCacheKey(userID, userAgent) + // get data from redis + hashedCacheKey := CreateAuthTokenHashedCacheKey(userID, userAgent) cachedToken, err := c.GetCache().Get(hashedCacheKey) if err != nil { @@ -73,28 +75,19 @@ func (s *SessionUser) Init(c *core.Context) bool { if res.Error != nil && !errors.Is(res.Error, gorm.ErrRecordNotFound) { pass = false } - // if have session start the struct if pass { - userAgent := c.GetUserAgent() - sessionKey := fmt.Sprintf("session%v_%v", userID, userAgent) - s.hashedSessionKey = fmt.Sprintf("%v", fmt.Sprintf("%x", md5.Sum([]byte(sessionKey)))) - value, err := c.GetCache().Get(s.hashedSessionKey) + sessionKey := fmt.Sprintf("sess_%v", userAgent) + s.hashedSessionKey = CreateAuthTokenHashedCacheKey(userID, sessionKey) - if err != nil { - s.values = make(map[string]interface{}) - s.authenticated = true - s.userID = userID + s.values = make(map[string]interface{}) + s.authenticated = true + s.userID = userID + value, _ := c.GetCache().Get(s.hashedSessionKey) - if len(value) > 0 { - err := json.Unmarshal([]byte(value), &s.values) - if err != nil { - - } - } else { - - } + if len(value) > 0 { + _ = json.Unmarshal([]byte(value), &s.values) } return true @@ -142,9 +135,9 @@ func (s *SessionUser) Delete(key string) interface{} { func (s *SessionUser) Flush() error { s.mu.Lock() - s.values = make(map[string]interface{}) + s.context.GetCache().Delete(s.hashedSessionKey) s.mu.Unlock() - return s.Save() + return nil } func (s *SessionUser) Save() error { @@ -160,15 +153,14 @@ func (s *SessionUser) Save() error { return err } value = string(buf) - } - s.mu.RUnlock() - if len(value) > 0 { s.context.GetCache().Set(s.hashedSessionKey, value) + } else { + s.context.GetCache().Delete(s.hashedSessionKey) } - + s.mu.RUnlock() return nil }