jiangshuai
2023-11-06 01b0783df9d576027d2393fb427226df4a5d7650
middleware/jwt.go
@@ -1,94 +1,39 @@
package middleware
import (
   "errors"
   "fmt"
   "strings"
   "time"
   "wms/conf"
   jwt "github.com/dgrijalva/jwt-go"
   "github.com/gin-gonic/gin"
   "wms/extend/util"
   "strings"
   "wms/pkg/contextx"
   "wms/pkg/ecode"
)
func validateToken(tokenString string) (util.JSON, error) {
   secretKey := []byte(conf.WebConf.JWTSecret)
   token, err := jwt.Parse(tokenString, func(token *jwt.Token) (interface{}, error) {
      // Don't forget to validate the alg is what you expect:
      if _, ok := token.Method.(*jwt.SigningMethodHMAC); !ok {
         return nil, fmt.Errorf("Unexpected signing method: %v", token.Header["alg"])
      }
      return secretKey, nil
   })
   if err != nil {
      return util.JSON{}, err
   }
   if !token.Valid {
      return util.JSON{}, errors.New("invalid token")
   }
   return token.Claims.(jwt.MapClaims), nil
}
// JWTMiddleware parses JWT token from cookie and stores data and expires date to the context
// JWT Token can be passed as cookie, or Authorization header
func JWTMiddleware() gin.HandlerFunc {
func JWTAuth() gin.HandlerFunc {
   return func(c *gin.Context) {
      tokenString, err := c.Cookie("token")
      // failed to read cookie
      if err != nil {
         // try reading HTTP Header
         authorization := c.Request.Header.Get("Authorization")
         if authorization == "" {
            c.Next()
            return
         }
         sp := strings.Split(authorization, "Bearer ")
         // invalid token
         if len(sp) < 1 {
            c.Next()
            return
         }
         tokenString = sp[1]
      ctx := new(contextx.Context).SetCtx(c)
      // 我们这里jwt鉴权取头部信息 Authorization 登录时回返回token信息 这里前端需要把token存储到cookie或者本地localStorage中 不过需要跟后端协商过期时间 可以约定刷新令牌或者重新登录
      token := c.Request.Header.Get("Authorization")
      if token == "" {
         ctx.Fail(ecode.JWTEmpty)
         c.Abort()
         return
      }
      tokenData, err := validateToken(tokenString)
      slices := strings.Split(token, " ")
      if len(slices) == 2 {
         token = slices[1]
      }
      j := NewJWT()
      // parseToken 解析token包含的信息
      claims, err := j.ParseToken(token)
      if err != nil {
         fmt.Println(err.Error())
         if err == TokenExpired {
            c.Next()
            return
         }
         c.Next()
         return
      }
      userParentId := tokenData["parentId"].(string)
      if userParentId == conf.WebConf.NodeId {
         c.Set("parentId", userParentId)
      } else {
         c.Next()
         return
      }
      c.Set("token_expire", tokenData["exp"])
      c.Set("claims", claims)
      c.Next()
   }
}
func GenerateToken(data interface{}) (string, error) {
   //  token is valid for 1 hour
   date := time.Now().Add(time.Hour * 12)
   token := jwt.NewWithClaims(jwt.SigningMethodHS256, jwt.MapClaims{
      "user": data,
      "exp":  date.Unix(),
   })
   secretKey := []byte(conf.WebConf.JWTSecret)
   tokenString, err := token.SignedString(secretKey)
   return tokenString, err
}