forked from ebhomengo/niki
130 lines
3.3 KiB
Go
130 lines
3.3 KiB
Go
|
// SPDX-License-Identifier: MIT
// SPDX-FileCopyrightText: © 2015 LabStack LLC and Echo contributors
|
||
|
|
||
|
package middleware
|
||
|
|
||
|
import (
|
||
|
"strings"
|
||
|
|
||
|
"github.com/labstack/echo/v4"
|
||
|
)
|
||
|
|
||
|
// TrailingSlashConfig defines the config for TrailingSlash middleware.
|
||
|
type TrailingSlashConfig struct {
|
||
|
// Skipper defines a function to skip middleware.
|
||
|
Skipper Skipper
|
||
|
|
||
|
// Status code to be used when redirecting the request.
|
||
|
// Optional, but when provided the request is redirected using this code.
|
||
|
RedirectCode int `yaml:"redirect_code"`
|
||
|
}
|
||
|
|
||
|
// DefaultTrailingSlashConfig is the default TrailingSlash middleware config.
|
||
|
var DefaultTrailingSlashConfig = TrailingSlashConfig{
|
||
|
Skipper: DefaultSkipper,
|
||
|
}
|
||
|
|
||
|
// AddTrailingSlash returns a root level (before router) middleware which adds a
|
||
|
// trailing slash to the request `URL#Path`.
|
||
|
//
|
||
|
// Usage `Echo#Pre(AddTrailingSlash())`
|
||
|
func AddTrailingSlash() echo.MiddlewareFunc {
|
||
|
return AddTrailingSlashWithConfig(DefaultTrailingSlashConfig)
|
||
|
}
|
||
|
|
||
|
// AddTrailingSlashWithConfig returns an AddTrailingSlash middleware with config.
|
||
|
// See `AddTrailingSlash()`.
|
||
|
func AddTrailingSlashWithConfig(config TrailingSlashConfig) echo.MiddlewareFunc {
|
||
|
// Defaults
|
||
|
if config.Skipper == nil {
|
||
|
config.Skipper = DefaultTrailingSlashConfig.Skipper
|
||
|
}
|
||
|
|
||
|
return func(next echo.HandlerFunc) echo.HandlerFunc {
|
||
|
return func(c echo.Context) error {
|
||
|
if config.Skipper(c) {
|
||
|
return next(c)
|
||
|
}
|
||
|
|
||
|
req := c.Request()
|
||
|
url := req.URL
|
||
|
path := url.Path
|
||
|
qs := c.QueryString()
|
||
|
if !strings.HasSuffix(path, "/") {
|
||
|
path += "/"
|
||
|
uri := path
|
||
|
if qs != "" {
|
||
|
uri += "?" + qs
|
||
|
}
|
||
|
|
||
|
// Redirect
|
||
|
if config.RedirectCode != 0 {
|
||
|
return c.Redirect(config.RedirectCode, sanitizeURI(uri))
|
||
|
}
|
||
|
|
||
|
// Forward
|
||
|
req.RequestURI = uri
|
||
|
url.Path = path
|
||
|
}
|
||
|
return next(c)
|
||
|
}
|
||
|
}
|
||
|
}
|
||
|
|
||
|
// RemoveTrailingSlash returns a root level (before router) middleware which removes
|
||
|
// a trailing slash from the request URI.
|
||
|
//
|
||
|
// Usage `Echo#Pre(RemoveTrailingSlash())`
|
||
|
func RemoveTrailingSlash() echo.MiddlewareFunc {
|
||
|
return RemoveTrailingSlashWithConfig(TrailingSlashConfig{})
|
||
|
}
|
||
|
|
||
|
// RemoveTrailingSlashWithConfig returns a RemoveTrailingSlash middleware with config.
|
||
|
// See `RemoveTrailingSlash()`.
|
||
|
func RemoveTrailingSlashWithConfig(config TrailingSlashConfig) echo.MiddlewareFunc {
|
||
|
// Defaults
|
||
|
if config.Skipper == nil {
|
||
|
config.Skipper = DefaultTrailingSlashConfig.Skipper
|
||
|
}
|
||
|
|
||
|
return func(next echo.HandlerFunc) echo.HandlerFunc {
|
||
|
return func(c echo.Context) error {
|
||
|
if config.Skipper(c) {
|
||
|
return next(c)
|
||
|
}
|
||
|
|
||
|
req := c.Request()
|
||
|
url := req.URL
|
||
|
path := url.Path
|
||
|
qs := c.QueryString()
|
||
|
l := len(path) - 1
|
||
|
if l > 0 && strings.HasSuffix(path, "/") {
|
||
|
path = path[:l]
|
||
|
uri := path
|
||
|
if qs != "" {
|
||
|
uri += "?" + qs
|
||
|
}
|
||
|
|
||
|
// Redirect
|
||
|
if config.RedirectCode != 0 {
|
||
|
return c.Redirect(config.RedirectCode, sanitizeURI(uri))
|
||
|
}
|
||
|
|
||
|
// Forward
|
||
|
req.RequestURI = uri
|
||
|
url.Path = path
|
||
|
}
|
||
|
return next(c)
|
||
|
}
|
||
|
}
|
||
|
}
|
||
|
|
||
|
// sanitizeURI makes uri safe to use as a redirect target by ensuring it cannot
// be interpreted as an absolute (scheme-relative) URI.
//
// Browsers treat a leading double slash `//`, `\\` or a mix such as `\/` as
// the start of an absolute URI, so redirecting to such a value is an open
// redirect. Browsers also drop ASCII tab, CR and LF while parsing a URL, which
// means `/<TAB>/host` is read as `//host`; those characters are therefore
// skipped over when looking for the leading slashes.
//
// When the leading run of slashes, backslashes and ignorable characters
// contains two or more slash characters, everything up to and including the
// last of those slashes is replaced by a single "/". Any other uri, including
// the empty string, is returned unchanged.
func sanitizeURI(uri string) string {
	slashes, cut := 0, 0
scan:
	for i := 0; i < len(uri); i++ {
		switch uri[i] {
		case '/', '\\':
			slashes++
			cut = i + 1 // remember where the last leading slash ends
		case '\t', '\r', '\n':
			// Dropped by browsers; keep scanning past it.
		default:
			break scan
		}
	}
	if slashes < 2 {
		return uri
	}
	return "/" + uri[cut:]
}
|