package auth

import (
	"context"
	"crypto/subtle"
	"fmt"
	"net/http"
	"strings"
	"time"

	"github.com/golang-jwt/jwt/v5"
	"github.com/rs/zerolog/log"
)

type contextKey string

// Context keys under which the middleware stores authentication results.
const (
	// UserKey holds the jwt.MapClaims of a request authenticated by JWT.
	UserKey contextKey = "user"
	// ServiceKey holds true for a request authenticated by a service token.
	ServiceKey contextKey = "service"
)

// Middleware authenticates HTTP requests that present either an HMAC-signed
// JWT or one of a fixed set of service tokens as an
// "Authorization: Bearer <token>" header.
type Middleware struct {
	jwtSecret     string
	serviceTokens []string
}

// NewMiddleware returns a Middleware that verifies JWTs with jwtSecret and
// accepts any of serviceTokens as a service credential.
//
// Empty entries in serviceTokens are discarded: an empty configured token
// would otherwise match an empty "Bearer " credential and grant service
// (and therefore admin) access. The slice is copied so later mutation by the
// caller does not affect the middleware.
func NewMiddleware(jwtSecret string, serviceTokens []string) *Middleware {
	tokens := make([]string, 0, len(serviceTokens))
	for _, t := range serviceTokens {
		if t != "" {
			tokens = append(tokens, t)
		}
	}
	return &Middleware{
		jwtSecret:     jwtSecret,
		serviceTokens: tokens,
	}
}

// bearerStatus is the outcome of extracting a bearer credential from a request.
type bearerStatus int

const (
	bearerOK        bearerStatus = iota // a non-empty Bearer credential was found
	bearerMissing                       // no Authorization header
	bearerMalformed                     // header present but not "Bearer <token>"
)

// bearerToken extracts the credential from an "Authorization: Bearer <token>"
// header. The scheme is matched case-insensitively (auth schemes are
// case-insensitive per RFC 7235); an empty credential is reported as malformed.
func bearerToken(r *http.Request) (string, bearerStatus) {
	header := r.Header.Get("Authorization")
	if header == "" {
		return "", bearerMissing
	}
	scheme, token, found := strings.Cut(header, " ")
	if !found || !strings.EqualFold(scheme, "Bearer") || token == "" {
		return "", bearerMalformed
	}
	return token, bearerOK
}

// withServiceContext returns r with ServiceKey=true stored in its context.
func withServiceContext(r *http.Request) *http.Request {
	return r.WithContext(context.WithValue(r.Context(), ServiceKey, true))
}

// AuthRequired accepts a request carrying either a valid service token or a
// valid JWT. Service-token requests get ServiceKey=true in their context; JWT
// requests get their claims stored under UserKey. Every failure responds 401.
func (m *Middleware) AuthRequired(next http.Handler) http.Handler {
	return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
		token, status := bearerToken(r)
		switch status {
		case bearerMissing:
			http.Error(w, "Authorization header required", http.StatusUnauthorized)
			return
		case bearerMalformed:
			http.Error(w, "Invalid authorization format. Use Bearer token", http.StatusUnauthorized)
			return
		}

		// Service tokens are an in-memory comparison, so try them before
		// parsing the credential as a JWT.
		if m.isValidServiceToken(token) {
			next.ServeHTTP(w, withServiceContext(r))
			return
		}

		claims, err := m.validateJWT(token)
		if err != nil {
			log.Warn().Err(err).Msg("Invalid JWT token")
			http.Error(w, "Invalid token", http.StatusUnauthorized)
			return
		}

		ctx := context.WithValue(r.Context(), UserKey, claims)
		next.ServeHTTP(w, r.WithContext(ctx))
	})
}

// ServiceTokenRequired accepts only requests carrying a valid service token
// (intended for internal service-to-service calls). On success ServiceKey=true
// is stored in the request context; every failure responds 401.
func (m *Middleware) ServiceTokenRequired(next http.Handler) http.Handler {
	return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
		token, status := bearerToken(r)
		switch status {
		case bearerMissing:
			http.Error(w, "Service authorization required", http.StatusUnauthorized)
			return
		case bearerMalformed:
			http.Error(w, "Invalid authorization format", http.StatusUnauthorized)
			return
		}

		if !m.isValidServiceToken(token) {
			http.Error(w, "Invalid service token", http.StatusUnauthorized)
			return
		}

		next.ServeHTTP(w, withServiceContext(r))
	})
}

// AdminRequired accepts requests carrying either a valid service token (which
// is treated as having admin privileges) or a valid JWT whose "role" claim is
// the string "admin". Missing or invalid credentials respond 401; a valid JWT
// without the admin role responds 403.
func (m *Middleware) AdminRequired(next http.Handler) http.Handler {
	return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
		token, status := bearerToken(r)
		switch status {
		case bearerMissing:
			http.Error(w, "Admin authorization required", http.StatusUnauthorized)
			return
		case bearerMalformed:
			http.Error(w, "Invalid authorization format", http.StatusUnauthorized)
			return
		}

		// Service tokens have admin privileges.
		if m.isValidServiceToken(token) {
			next.ServeHTTP(w, withServiceContext(r))
			return
		}

		claims, err := m.validateJWT(token)
		if err != nil {
			log.Warn().Err(err).Msg("Invalid JWT token for admin access")
			http.Error(w, "Invalid admin token", http.StatusUnauthorized)
			return
		}

		if role, ok := claims["role"].(string); !ok || role != "admin" {
			http.Error(w, "Admin privileges required", http.StatusForbidden)
			return
		}

		ctx := context.WithValue(r.Context(), UserKey, claims)
		next.ServeHTTP(w, r.WithContext(ctx))
	})
}

// isValidServiceToken reports whether token equals one of the configured
// service tokens. Every configured token is compared in constant time, with no
// early exit, so response timing does not reveal how much of a secret matched
// (only length mismatches short-circuit inside ConstantTimeCompare). Empty
// tokens never match.
func (m *Middleware) isValidServiceToken(token string) bool {
	if token == "" {
		return false
	}
	candidate := []byte(token)
	match := 0
	for _, serviceToken := range m.serviceTokens {
		if serviceToken == "" {
			continue
		}
		match |= subtle.ConstantTimeCompare(candidate, []byte(serviceToken))
	}
	return match == 1
}

// validateJWT parses tokenString, verifies its HMAC signature against the
// configured secret and returns its claims. Tokens signed with a non-HMAC
// algorithm are rejected. An empty secret is refused outright, because an
// HMAC keyed with an empty secret would let anyone mint valid tokens.
func (m *Middleware) validateJWT(tokenString string) (jwt.MapClaims, error) {
	if m.jwtSecret == "" {
		return nil, fmt.Errorf("jwt secret not configured")
	}

	token, err := jwt.Parse(tokenString, func(token *jwt.Token) (interface{}, error) {
		// Restrict to HMAC so an attacker cannot switch the algorithm
		// (e.g. "none" or an asymmetric alg) in the token header.
		if _, ok := token.Method.(*jwt.SigningMethodHMAC); !ok {
			return nil, fmt.Errorf("unexpected signing method: %v", token.Header["alg"])
		}
		return []byte(m.jwtSecret), nil
	})
	if err != nil {
		return nil, fmt.Errorf("parsing jwt: %w", err)
	}
	if !token.Valid {
		return nil, fmt.Errorf("invalid token")
	}

	claims, ok := token.Claims.(jwt.MapClaims)
	if !ok {
		return nil, fmt.Errorf("invalid claims")
	}

	// jwt/v5 already validates "exp" during Parse; this explicit check is kept
	// as defense in depth.
	if exp, ok := claims["exp"].(float64); ok {
		if time.Unix(int64(exp), 0).Before(time.Now()) {
			return nil, fmt.Errorf("token expired")
		}
	}

	return claims, nil
}

// GetUserFromContext returns the JWT claims stored by AuthRequired or
// AdminRequired, and whether any were present.
func GetUserFromContext(ctx context.Context) (jwt.MapClaims, bool) {
	claims, ok := ctx.Value(UserKey).(jwt.MapClaims)
	return claims, ok
}

// IsServiceRequest reports whether the request was authenticated with a
// service token.
func IsServiceRequest(ctx context.Context) bool {
	service, ok := ctx.Value(ServiceKey).(bool)
	return ok && service
}