| 123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252 | package middlewareimport (	"fmt"	"io"	"math/rand"	"net"	"net/http"	"net/http/httputil"	"net/url"	"regexp"	"strings"	"sync"	"sync/atomic"	"time"	"github.com/labstack/echo")// TODO: Handle TLS proxytype (	// ProxyConfig defines the config for Proxy middleware.	ProxyConfig struct {		// Skipper defines a function to skip middleware.		Skipper Skipper		// Balancer defines a load balancing technique.		// Required.		Balancer ProxyBalancer		// Rewrite defines URL path rewrite rules. The values captured in asterisk can be		// retrieved by index e.g. $1, $2 and so on.		// Examples:		// "/old":              "/new",		// "/api/*":            "/$1",		// "/js/*":             "/public/javascripts/$1",		// "/users/*/orders/*": "/user/$1/order/$2",		Rewrite map[string]string		rewriteRegex map[*regexp.Regexp]string	}	// ProxyTarget defines the upstream target.	ProxyTarget struct {		Name string		URL  *url.URL	}	// ProxyBalancer defines an interface to implement a load balancing technique.	ProxyBalancer interface {		AddTarget(*ProxyTarget) bool		RemoveTarget(string) bool		Next() *ProxyTarget	}	commonBalancer struct {		targets []*ProxyTarget		mutex   sync.RWMutex	}	// RandomBalancer implements a random load balancing technique.	randomBalancer struct {		*commonBalancer		random *rand.Rand	}	// RoundRobinBalancer implements a round-robin load balancing technique.	roundRobinBalancer struct {		*commonBalancer		i uint32	})var (	// DefaultProxyConfig is the default Proxy middleware config.	DefaultProxyConfig = ProxyConfig{		Skipper: DefaultSkipper,	})func proxyHTTP(t *ProxyTarget) http.Handler {	return httputil.NewSingleHostReverseProxy(t.URL)}func proxyRaw(t *ProxyTarget, c echo.Context) http.Handler {	return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {		in, _, err := c.Response().Hijack()		if err != nil {			c.Error(fmt.Errorf("proxy raw, hijack error=%v, url=%s", t.URL, err))			return		}		defer in.Close()		out, err := net.Dial("tcp", t.URL.Host)		if err != nil {			he := echo.NewHTTPError(http.StatusBadGateway, fmt.Sprintf("proxy raw, dial error=%v, url=%s", t.URL, err))			c.Error(he)			return		}		defer out.Close()		// Write header		err = r.Write(out)		if err != nil {			he := echo.NewHTTPError(http.StatusBadGateway, fmt.Sprintf("proxy raw, request header copy error=%v, url=%s", t.URL, err))			c.Error(he)			return		}		errCh := make(chan error, 2)		cp := func(dst io.Writer, src io.Reader) {			_, err = io.Copy(dst, src)			errCh <- err		}		go cp(out, in)		go cp(in, out)		err = <-errCh		if err != nil && err != io.EOF {			c.Logger().Errorf("proxy raw, copy body error=%v, url=%s", t.URL, err)		}	})}// NewRandomBalancer returns a random proxy balancer.func NewRandomBalancer(targets []*ProxyTarget) ProxyBalancer {	b := &randomBalancer{commonBalancer: new(commonBalancer)}	b.targets = targets	return b}// NewRoundRobinBalancer returns a round-robin proxy balancer.func NewRoundRobinBalancer(targets []*ProxyTarget) ProxyBalancer {	b := &roundRobinBalancer{commonBalancer: new(commonBalancer)}	b.targets = targets	return b}// AddTarget adds an upstream target to the list.func (b *commonBalancer) AddTarget(target *ProxyTarget) bool {	for _, t := range b.targets {		if t.Name == target.Name {			return false		}	}	b.mutex.Lock()	defer b.mutex.Unlock()	b.targets = append(b.targets, target)	return true}// RemoveTarget removes an upstream target from the list.func (b *commonBalancer) RemoveTarget(name string) bool {	b.mutex.Lock()	defer b.mutex.Unlock()	for i, t := range b.targets {		if t.Name == name {			b.targets = append(b.targets[:i], b.targets[i+1:]...)			return true		}	}	return false}// Next randomly returns an upstream target.func (b *randomBalancer) Next() *ProxyTarget {	if b.random == nil {		b.random = rand.New(rand.NewSource(int64(time.Now().Nanosecond())))	}	b.mutex.RLock()	defer b.mutex.RUnlock()	return b.targets[b.random.Intn(len(b.targets))]}// Next returns an upstream target using round-robin technique.func (b *roundRobinBalancer) Next() *ProxyTarget {	b.i = b.i % uint32(len(b.targets))	t := b.targets[b.i]	atomic.AddUint32(&b.i, 1)	return t}// Proxy returns a Proxy middleware.//// Proxy middleware forwards the request to upstream server using a configured load balancing technique.func Proxy(balancer ProxyBalancer) echo.MiddlewareFunc {	c := DefaultProxyConfig	c.Balancer = balancer	return ProxyWithConfig(c)}// ProxyWithConfig returns a Proxy middleware with config.// See: `Proxy()`func ProxyWithConfig(config ProxyConfig) echo.MiddlewareFunc {	// Defaults	if config.Skipper == nil {		config.Skipper = DefaultLoggerConfig.Skipper	}	if config.Balancer == nil {		panic("echo: proxy middleware requires balancer")	}	config.rewriteRegex = map[*regexp.Regexp]string{}	// Initialize	for k, v := range config.Rewrite {		k = strings.Replace(k, "*", "(\\S*)", -1)		config.rewriteRegex[regexp.MustCompile(k)] = v	}	return func(next echo.HandlerFunc) echo.HandlerFunc {		return func(c echo.Context) (err error) {			if config.Skipper(c) {				return next(c)			}			req := c.Request()			res := c.Response()			tgt := config.Balancer.Next()			// Rewrite			for k, v := range config.rewriteRegex {				replacer := captureTokens(k, req.URL.Path)				if replacer != nil {					req.URL.Path = replacer.Replace(v)				}			}			// Fix header			if req.Header.Get(echo.HeaderXRealIP) == "" {				req.Header.Set(echo.HeaderXRealIP, c.RealIP())			}			if req.Header.Get(echo.HeaderXForwardedProto) == "" {				req.Header.Set(echo.HeaderXForwardedProto, c.Scheme())			}			if c.IsWebSocket() && req.Header.Get(echo.HeaderXForwardedFor) == "" { // For HTTP, it is automatically set by Go HTTP reverse proxy.				req.Header.Set(echo.HeaderXForwardedFor, c.RealIP())			}			// Proxy			switch {			case c.IsWebSocket():				proxyRaw(tgt, c).ServeHTTP(res, req)			case req.Header.Get(echo.HeaderAccept) == "text/event-stream":			default:				proxyHTTP(tgt).ServeHTTP(res, req)			}			return		}	}}
 |