package httpproxy

import (
	"bytes"
	"crypto/tls"
	"fmt"
	"io/ioutil"
	"log"
	"net/http"
	"net/http/httputil"
	"net/url"
	"os"
	"os/signal"
	"strings"
	"sync"
	"syscall"
	"time"

	"golang.org/x/crypto/acme/autocert"
)

// ProxyOption is the definition for a Proxy Option func.
type ProxyOption func(p *Proxy)

// WithLogger adds a logger to the proxy.
func WithLogger(logger *log.Logger) ProxyOption {
	return func(p *Proxy) {
		p.logger = logger
	}
}

// WithAllowedHosts sets the allowed hosts.
func WithAllowedHosts(hosts []string) ProxyOption {
	return func(p *Proxy) {
		p.allowedHosts = hosts
	}
}

// WithErrorServerHeader is the Server Header to use when the communication
// between the proxy and the matched service fails in some way.
func WithErrorServerHeader(serverHeader []string) ProxyOption {
	return func(p *Proxy) {
		p.errorServerHeader = serverHeader
	}
}

// WithErrorBody is the response body to use when the communication between the
// proxy and the matched service fails in some way.
func WithErrorBody(body []byte) ProxyOption {
	return func(p *Proxy) {
		p.errorBody = body
	}
}

// WithHTTPSRedirect redirects insecure HTTP calls to HTTPS.
func WithHTTPSRedirect() ProxyOption {
	return func(p *Proxy) {
		p.httpsRedirect = true
	}
}

// Proxy is the main proxy struct.
type Proxy struct {
	logger            *log.Logger
	allowedHosts      []string
	handler           proxyHandler
	ruleMutex         sync.RWMutex
	certMan           *autocert.Manager
	errorServerHeader []string
	errorBody         []byte
	httpsRedirect     bool

	// Metrics
	VisitsPerServiceAndPath map[string]map[string]int
	StatusPerServiceAndPath map[string]map[string]map[int]int
	TotalVisits             int
	TotalErrors             int
}

// NewProxy sets up everything needed to get a running proxy.
// TODO: Make Proxy less chatty?
func NewProxy(opts ...ProxyOption) *Proxy {
	p := Proxy{
		logger:            nil,
		allowedHosts:      []string{},
		errorServerHeader: []string{"httpproxy"},
		errorBody:         []byte("error communicating with matched service"),

		VisitsPerServiceAndPath: make(map[string]map[string]int),
		StatusPerServiceAndPath: make(map[string]map[string]map[int]int),
	}

	for _, opt := range opts {
		opt(&p)
	}

	p.handler.logger = p.logger
	p.handler.ruleMutex = &p.ruleMutex
	p.handler.proxy = &p
	p.certMan = &autocert.Manager{
		Cache:      autocert.DirCache("certs"),
		Prompt:     autocert.AcceptTOS,
		HostPolicy: autocert.HostWhitelist(p.allowedHosts...),
	}

	return &p
}

// TODO: Store metrics on disk
func (p *Proxy) incVisitMetric(service, path string) {
	if _, ok := p.VisitsPerServiceAndPath[service]; !ok {
		p.VisitsPerServiceAndPath[service] = make(map[string]int)
	}
	p.VisitsPerServiceAndPath[service][path]++
	p.TotalVisits++
}

// TODO: Store metrics on disk
func (p *Proxy) incStatusMetric(service, path string, status int) {
	if _, ok := p.StatusPerServiceAndPath[service]; !ok {
		p.StatusPerServiceAndPath[service] = make(map[string]map[int]int)
	}
	if _, ok := p.StatusPerServiceAndPath[service][path]; !ok {
		p.StatusPerServiceAndPath[service][path] = make(map[int]int)
	}
	p.StatusPerServiceAndPath[service][path][status]++
	if status >= 400 {
		p.TotalErrors++
	}
}

// SetAllowedHosts sets the allowed hosts.
func (p *Proxy) SetAllowedHosts(hosts []string) {
	p.allowedHosts = hosts
	p.certMan.HostPolicy = autocert.HostWhitelist(p.allowedHosts...)
}

// AddRule adds a new rule to the proxy engine.
func (p *Proxy) AddRule(name string, match string, destinationPort int) error {
	p.ruleMutex.Lock()
	defer p.ruleMutex.Unlock()

	remoteURL, err := url.Parse(
		fmt.Sprintf("http://127.0.0.1:%d/", destinationPort),
	)
	if err != nil {
		return fmt.Errorf("could not parse target url: %w", err)
	}
	reverseProxy := httputil.NewSingleHostReverseProxy(remoteURL)
	reverseProxy.Transport = &proxyTransport{
		logger:            p.logger,
		proxy:             p,
		errorServerHeader: p.errorServerHeader,
		errorBody:         p.errorBody,
	}
	reverseProxy.Transport = http.DefaultTransport
	p.handler.Rules = append(p.handler.Rules, rule{
		Name:        name,
		Match:       match,
		Destination: destinationPort,
		Proxy:       reverseProxy,
	})

	return nil
}

// ClearRules clears all rules from the proxy.
func (p *Proxy) ClearRules() {
	p.ruleMutex.Lock()
	defer p.ruleMutex.Unlock()
	p.handler.Rules = []rule{}
}

// ListenAndServe starts the proxy.
func (p *Proxy) ListenAndServe() {
	serverFailed := make(chan interface{})
	tlsServerFailed := make(chan interface{})
	go (func() {
		defer close(serverFailed)

		var handler http.Handler
		if p.httpsRedirect {
			handler = p.certMan.HTTPHandler(nil)
		} else {
			handler = p.handler
		}
		server := &http.Server{
			Addr:           ":8000",
			Handler:        handler,
			ReadTimeout:    90 * time.Second,
			WriteTimeout:   90 * time.Second,
			MaxHeaderBytes: 1 << 20,
		}
		err := server.ListenAndServe()
		if err != nil {
			p.logger.Fatal("could not start server,", err)
		}
	})()
	go (func() {
		defer close(tlsServerFailed)

		server := &http.Server{
			Addr:    ":8443",
			Handler: p.handler,
			TLSConfig: &tls.Config{
				GetCertificate: p.certMan.GetCertificate,
			},
			ReadTimeout:    90 * time.Second,
			WriteTimeout:   90 * time.Second,
			MaxHeaderBytes: 1 << 20,
		}
		err := server.ListenAndServeTLS("", "")
		if err != nil {
			p.logger.Fatal("could not start tls server,", err)
		}
	})()

	go (func() {
		server := &http.Server{
			Addr: ":9090",
			Handler: metricsHandler{
				proxy: p,
				rules: p.handler.Rules,
			},
			ReadTimeout:    90 * time.Second,
			WriteTimeout:   90 * time.Second,
			MaxHeaderBytes: 1 << 20,
		}
		err := server.ListenAndServe()
		if err != nil {
			p.logger.Fatal("could not serve metrics,", err)
		}
	})()

	stop := make(chan os.Signal, 1)
	signal.Notify(stop, os.Interrupt, syscall.SIGTERM)
	select {
	case <-stop:
	case <-serverFailed:
	case <-tlsServerFailed:
	}
}

type rule struct {
	Name        string
	Match       string
	Destination int
	Proxy       *httputil.ReverseProxy
}

type proxyTransport struct {
	http.RoundTripper
	logger            *log.Logger
	proxy             *Proxy
	errorServerHeader []string
	errorBody         []byte
}

func (t *proxyTransport) RoundTrip(r *http.Request) (*http.Response, error) {
	resp, err := t.RoundTripper.RoundTrip(r)
	if err != nil {
		// t.logger.Printf("error when talking to service: %s", err.Error())
		resp = &http.Response{
			Status:     "500 INTERNAL SERVER ERROR",
			StatusCode: 500,
			Proto:      r.Proto,
			ProtoMajor: r.ProtoMajor,
			ProtoMinor: r.ProtoMinor,
			Header: http.Header{
				"Server": t.errorServerHeader,
			},
			Body:             ioutil.NopCloser(bytes.NewBuffer(t.errorBody)),
			ContentLength:    0,
			TransferEncoding: r.TransferEncoding,
			Close:            true,
			Uncompressed:     false,
			Trailer:          http.Header{},
			Request:          nil,
			TLS:              r.TLS,
		}
	}

	t.proxy.incStatusMetric("valkyr", r.URL.String(), resp.StatusCode)

	return resp, nil
}

type proxyHandler struct {
	logger    *log.Logger
	ruleMutex *sync.RWMutex
	proxy     *Proxy
	Rules     []rule
}

// TODO: Cleanup matching logic.
func (h proxyHandler) ServeHTTP(w http.ResponseWriter, r *http.Request) {
	h.ruleMutex.RLock()
	defer h.ruleMutex.RUnlock()

	reqHost := strings.Split(r.Host, ":")[0]

	pathParts := []string{"/"}
	if r.URL.Path != "/" {
		pathParts = strings.Split(r.URL.Path, "/")
		pathParts[0] = "/"
	}

	longest := -1
	var matched rule
	for _, rule := range h.Rules {
		ruleHostPath := strings.Split(rule.Match, "/")
		ruleHost := ruleHostPath[0]
		ruleHostPath[0] = "/"

		if ruleHost != reqHost {
			continue
		}

		if len(ruleHostPath) <= len(pathParts) {
			for t := range ruleHostPath {
				if ruleHostPath[t] != pathParts[t] {
					break
				} else {
					if t > longest {
						longest = t
						matched = rule
					}
				}
			}
		}
	}

	if longest > -1 {
		parts := strings.Split(r.URL.Path, "/")
		if len(parts) == longest+1 {
			/*h.logger.Printf(
				"redirecting to %s:%s, missing root", matched.Name, r.URL.String()+"/",
			)*/
			h.proxy.incStatusMetric(matched.Name, r.URL.String(), http.StatusMovedPermanently)
			http.Redirect(w, r, r.URL.String()+"/", http.StatusMovedPermanently)
			return
		}
		path := "/" + strings.Join(parts[longest+1:], "/")
		r.URL.Path = path

		h.proxy.incVisitMetric(matched.Name, r.URL.String())
		// h.logger.Printf("proxying to %s:%s", matched.Name, path)
		matched.Proxy.ServeHTTP(w, r)
		return
	}

	h.proxy.incVisitMetric("valkyr", r.URL.String())
	h.proxy.incStatusMetric("valkyr", r.URL.String(), http.StatusNotFound)
	http.Error(w, "404 NOT FOUND", http.StatusNotFound)
}

type metricsHandler struct {
	proxy *Proxy
	rules []rule
}

func (h metricsHandler) ServeHTTP(w http.ResponseWriter, r *http.Request) {
	if r.URL.String() == "/metrics" {
		w.WriteHeader(http.StatusOK)
		w.Write([]byte("services:\n"))
		services := make([]string, 0, 10)
		services = append(services, "valkyr")
		for _, r := range h.rules {
			services = append(services, r.Name)
		}
		for _, service := range services {
			w.Write([]byte(fmt.Sprintf("\t%s\n", service)))
			for path, count := range h.proxy.VisitsPerServiceAndPath[service] {
				w.Write([]byte(fmt.Sprintf("\t\t%s: %d\n", path, count)))
				for statusCode, statusCount := range h.proxy.StatusPerServiceAndPath[service][path] {
					w.Write([]byte(fmt.Sprintf("\t\t\t%d: %d\n", statusCode, statusCount)))
				}
			}
		}
		w.Write([]byte("---\n"))
		w.Write([]byte(fmt.Sprintf("total_visits: %d\n", h.proxy.TotalVisits)))
		w.Write([]byte(fmt.Sprintf("total_errors: %d\n", h.proxy.TotalErrors)))
	} else if r.URL.String() == "/healthcheck" {
		w.WriteHeader(http.StatusOK)
		w.Write([]byte("OK\n"))
		return
	}

}