package httpproxy import ( "bytes" "crypto/tls" "fmt" "io/ioutil" "log" "net/http" "net/http/httputil" "net/url" "os" "os/signal" "strings" "sync" "syscall" "time" "golang.org/x/crypto/acme/autocert" ) // ProxyOption is the definition for a Proxy Option func. type ProxyOption func(p *Proxy) // WithLogger adds a logger to the proxy. func WithLogger(logger *log.Logger) ProxyOption { return func(p *Proxy) { p.logger = logger } } // WithAllowedHosts sets the allowed hosts. func WithAllowedHosts(hosts []string) ProxyOption { return func(p *Proxy) { p.allowedHosts = hosts } } // WithErrorServerHeader is the Server Header to use when the communication // between the proxy and the matched service fails in some way. func WithErrorServerHeader(serverHeader []string) ProxyOption { return func(p *Proxy) { p.errorServerHeader = serverHeader } } // WithErrorBody is the response body to use when the communication between the // proxy and the matched service fails in some way. func WithErrorBody(body []byte) ProxyOption { return func(p *Proxy) { p.errorBody = body } } // WithHTTPSRedirect redirects insecure HTTP calls to HTTPS. func WithHTTPSRedirect() ProxyOption { return func(p *Proxy) { p.httpsRedirect = true } } // Proxy is the main proxy struct. type Proxy struct { logger *log.Logger allowedHosts []string handler proxyHandler ruleMutex sync.RWMutex certMan *autocert.Manager errorServerHeader []string errorBody []byte httpsRedirect bool // Metrics VisitsPerServiceAndPath map[string]map[string]int StatusPerServiceAndPath map[string]map[string]map[int]int TotalVisits int TotalErrors int } // NewProxy sets up everything needed to get a running proxy. // TODO: Make Proxy less chatty? func NewProxy(opts ...ProxyOption) *Proxy { p := Proxy{ logger: nil, allowedHosts: []string{}, errorServerHeader: []string{"httpproxy"}, errorBody: []byte("error communicating with matched service"), VisitsPerServiceAndPath: make(map[string]map[string]int), StatusPerServiceAndPath: make(map[string]map[string]map[int]int), } for _, opt := range opts { opt(&p) } p.handler.logger = p.logger p.handler.ruleMutex = &p.ruleMutex p.handler.proxy = &p p.certMan = &autocert.Manager{ Cache: autocert.DirCache("certs"), Prompt: autocert.AcceptTOS, HostPolicy: autocert.HostWhitelist(p.allowedHosts...), } return &p } // TODO: Store metrics on disk func (p *Proxy) incVisitMetric(service, path string) { if _, ok := p.VisitsPerServiceAndPath[service]; !ok { p.VisitsPerServiceAndPath[service] = make(map[string]int) } p.VisitsPerServiceAndPath[service][path]++ p.TotalVisits++ } // TODO: Store metrics on disk func (p *Proxy) incStatusMetric(service, path string, status int) { if _, ok := p.StatusPerServiceAndPath[service]; !ok { p.StatusPerServiceAndPath[service] = make(map[string]map[int]int) } if _, ok := p.StatusPerServiceAndPath[service][path]; !ok { p.StatusPerServiceAndPath[service][path] = make(map[int]int) } p.StatusPerServiceAndPath[service][path][status]++ if status >= 400 { p.TotalErrors++ } } // SetAllowedHosts sets the allowed hosts. func (p *Proxy) SetAllowedHosts(hosts []string) { p.allowedHosts = hosts p.certMan.HostPolicy = autocert.HostWhitelist(p.allowedHosts...) } // AddRule adds a new rule to the proxy engine. func (p *Proxy) AddRule(name string, match string, destinationPort int) error { p.ruleMutex.Lock() defer p.ruleMutex.Unlock() remoteURL, err := url.Parse( fmt.Sprintf("http://127.0.0.1:%d/", destinationPort), ) if err != nil { return fmt.Errorf("could not parse target url: %w", err) } reverseProxy := httputil.NewSingleHostReverseProxy(remoteURL) reverseProxy.Transport = &proxyTransport{ logger: p.logger, proxy: p, errorServerHeader: p.errorServerHeader, errorBody: p.errorBody, } reverseProxy.Transport = http.DefaultTransport p.handler.Rules = append(p.handler.Rules, rule{ Name: name, Match: match, Destination: destinationPort, Proxy: reverseProxy, }) return nil } // ClearRules clears all rules from the proxy. func (p *Proxy) ClearRules() { p.ruleMutex.Lock() defer p.ruleMutex.Unlock() p.handler.Rules = []rule{} } // ListenAndServe starts the proxy. func (p *Proxy) ListenAndServe() { serverFailed := make(chan interface{}) tlsServerFailed := make(chan interface{}) go (func() { defer close(serverFailed) var handler http.Handler if p.httpsRedirect { handler = p.certMan.HTTPHandler(nil) } else { handler = p.handler } server := &http.Server{ Addr: ":8000", Handler: handler, ReadTimeout: 90 * time.Second, WriteTimeout: 90 * time.Second, MaxHeaderBytes: 1 << 20, } err := server.ListenAndServe() if err != nil { p.logger.Fatal("could not start server,", err) } })() go (func() { defer close(tlsServerFailed) server := &http.Server{ Addr: ":8443", Handler: p.handler, TLSConfig: &tls.Config{ GetCertificate: p.certMan.GetCertificate, }, ReadTimeout: 90 * time.Second, WriteTimeout: 90 * time.Second, MaxHeaderBytes: 1 << 20, } err := server.ListenAndServeTLS("", "") if err != nil { p.logger.Fatal("could not start tls server,", err) } })() go (func() { server := &http.Server{ Addr: ":9090", Handler: metricsHandler{ proxy: p, rules: p.handler.Rules, }, ReadTimeout: 90 * time.Second, WriteTimeout: 90 * time.Second, MaxHeaderBytes: 1 << 20, } err := server.ListenAndServe() if err != nil { p.logger.Fatal("could not serve metrics,", err) } })() stop := make(chan os.Signal, 1) signal.Notify(stop, os.Interrupt, syscall.SIGTERM) select { case <-stop: case <-serverFailed: case <-tlsServerFailed: } } type rule struct { Name string Match string Destination int Proxy *httputil.ReverseProxy } type proxyTransport struct { http.RoundTripper logger *log.Logger proxy *Proxy errorServerHeader []string errorBody []byte } func (t *proxyTransport) RoundTrip(r *http.Request) (*http.Response, error) { resp, err := t.RoundTripper.RoundTrip(r) if err != nil { // t.logger.Printf("error when talking to service: %s", err.Error()) resp = &http.Response{ Status: "500 INTERNAL SERVER ERROR", StatusCode: 500, Proto: r.Proto, ProtoMajor: r.ProtoMajor, ProtoMinor: r.ProtoMinor, Header: http.Header{ "Server": t.errorServerHeader, }, Body: ioutil.NopCloser(bytes.NewBuffer(t.errorBody)), ContentLength: 0, TransferEncoding: r.TransferEncoding, Close: true, Uncompressed: false, Trailer: http.Header{}, Request: nil, TLS: r.TLS, } } t.proxy.incStatusMetric("valkyr", r.URL.String(), resp.StatusCode) return resp, nil } type proxyHandler struct { logger *log.Logger ruleMutex *sync.RWMutex proxy *Proxy Rules []rule } // TODO: Cleanup matching logic. func (h proxyHandler) ServeHTTP(w http.ResponseWriter, r *http.Request) { h.ruleMutex.RLock() defer h.ruleMutex.RUnlock() reqHost := strings.Split(r.Host, ":")[0] pathParts := []string{"/"} if r.URL.Path != "/" { pathParts = strings.Split(r.URL.Path, "/") pathParts[0] = "/" } longest := -1 var matched rule for _, rule := range h.Rules { ruleHostPath := strings.Split(rule.Match, "/") ruleHost := ruleHostPath[0] ruleHostPath[0] = "/" if ruleHost != reqHost { continue } if len(ruleHostPath) <= len(pathParts) { for t := range ruleHostPath { if ruleHostPath[t] != pathParts[t] { break } else { if t > longest { longest = t matched = rule } } } } } if longest > -1 { parts := strings.Split(r.URL.Path, "/") if len(parts) == longest+1 { /*h.logger.Printf( "redirecting to %s:%s, missing root", matched.Name, r.URL.String()+"/", )*/ h.proxy.incStatusMetric(matched.Name, r.URL.String(), http.StatusMovedPermanently) http.Redirect(w, r, r.URL.String()+"/", http.StatusMovedPermanently) return } path := "/" + strings.Join(parts[longest+1:], "/") r.URL.Path = path h.proxy.incVisitMetric(matched.Name, r.URL.String()) // h.logger.Printf("proxying to %s:%s", matched.Name, path) matched.Proxy.ServeHTTP(w, r) return } h.proxy.incVisitMetric("valkyr", r.URL.String()) h.proxy.incStatusMetric("valkyr", r.URL.String(), http.StatusNotFound) http.Error(w, "404 NOT FOUND", http.StatusNotFound) } type metricsHandler struct { proxy *Proxy rules []rule } func (h metricsHandler) ServeHTTP(w http.ResponseWriter, r *http.Request) { if r.URL.String() == "/metrics" { w.WriteHeader(http.StatusOK) w.Write([]byte("services:\n")) services := make([]string, 0, 10) services = append(services, "valkyr") for _, r := range h.rules { services = append(services, r.Name) } for _, service := range services { w.Write([]byte(fmt.Sprintf("\t%s\n", service))) for path, count := range h.proxy.VisitsPerServiceAndPath[service] { w.Write([]byte(fmt.Sprintf("\t\t%s: %d\n", path, count))) for statusCode, statusCount := range h.proxy.StatusPerServiceAndPath[service][path] { w.Write([]byte(fmt.Sprintf("\t\t\t%d: %d\n", statusCode, statusCount))) } } } w.Write([]byte("---\n")) w.Write([]byte(fmt.Sprintf("total_visits: %d\n", h.proxy.TotalVisits))) w.Write([]byte(fmt.Sprintf("total_errors: %d\n", h.proxy.TotalErrors))) } else if r.URL.String() == "/healthcheck" { w.WriteHeader(http.StatusOK) w.Write([]byte("OK\n")) return } }