// 🍯 Glaze

package router

import "core:fmt"
import "core:os"
import "core:strings"

import http "../"

// Route_Kind selects how a route's path is compared with the request target:
//   .Exact   - the target must equal the path;
//   .Prefix  - the path must be a prefix of the target (the longest one wins);
//   .Pattern - path and target are split on '/' and compared segment by segment,
//              where ":name" and a trailing "*name" segment capture parts of the target.
Route_Kind :: enum {Exact, Prefix, Pattern}

// Route is one registration made through `add` (or the `get`/`post`/`static` helpers).
Route :: struct {
	method:    http.Method,   // request method this route answers
	kind:      Route_Kind,    // how `path` is compared with the request target
	path:      string,        // stored as given by `add` (not copied); must outlive the router
	handler:   Route_Handler, // invoked (through the middleware chain) on a match
	user_data: rawptr,        // handed back to `handler` as its `ctx` argument
}

// Middleware is an entry registered with `use`.
Middleware :: struct {
	handler:   Middleware_Handler, // wraps every matched route handler
	user_data: rawptr,             // handed back to `handler` as its `ctx` argument
}

// Chain_Node is one link of the per-request middleware chain built by build_chain:
// it pairs a middleware with the handler/context that middleware continues to.
Chain_Node :: struct {
	middleware:     Middleware_Handler, // the middleware this link runs
	middleware_ctx: rawptr,             // that middleware's user_data from `use`
	next:           Route_Handler,      // passed to the middleware as its `next`
	next_ctx:       rawptr,             // passed to the middleware as its `next_ctx`
}

// Router holds the registered routes and middlewares, plus the MIME table used by
// static file routes. Build one with `create` and release it with `destroy`.
Router :: struct {
	routes:      [dynamic]Route,      // appended to by `add`, scanned on every request
	middlewares: [dynamic]Middleware, // appended to by `use`, wrapped around every matched route
	mime_map:    map[string]string,   // file extension including the dot (e.g. ".html") -> content type
}

// Request is what route handlers and middlewares receive: the underlying
// http.Request (its fields are reachable directly through `using`) plus the
// results of routing.
Request :: struct {
	using _http_req: http.Request,
	rest_path:       string,            // .Prefix routes: the target with the matched prefix removed
	params:          map[string]string, // .Pattern routes: ":name"/"*name" captures; deleted once the handler returns
}

// Response is the http package's response type, aliased so handler signatures only need this package.
Response :: http.Response

// Route_Handler handles a matched request; `ctx` is the user_data given when the route was registered.
Route_Handler :: proc(req: Request, res: ^Response, ctx: rawptr)

// Middleware_Handler wraps a matched route. `ctx` is the user_data given to `use`.
// Calling `next(req, res, next_ctx)` continues down the chain; returning without
// calling it skips the remaining middlewares and the route handler.
Middleware_Handler :: proc(req: Request, res: ^Response, ctx: rawptr, next: Route_Handler, next_ctx: rawptr)

// create returns an empty router whose MIME table is pre-filled with the content
// types used when serving static files. Release it with `destroy`.
create :: proc() -> Router {
	// TODO: break out into dedicated file to enable clean imports
	// Might as well expose mime handling
	default_mime_types := [?][2]string{
		{".html", "text/html; charset=utf-8"},
		{".css", "text/css; charset=utf-8"},
		{".js", "application/javascript; charset=utf-8"},
		{".json", "application/json; charset=utf-8"},
		{".svg", "image/svg+xml"},
		{".png", "image/png"},
		{".jpg", "image/jpeg"},
		{".jpeg", "image/jpeg"},
		{".webp", "image/webp"},
		{".gif", "image/gif"},
		{".ico", "image/x-icon"},
		{".txt", "text/plain; charset=utf-8"},
	}

	router := Router{
		routes      = make([dynamic]Route, 0, 8),
		middlewares = make([dynamic]Middleware, 0, 8),
		mime_map    = make(map[string]string),
	}
	for entry in default_mime_types {
		extension, content_type := entry[0], entry[1]
		router.mime_map[extension] = content_type
	}
	return router
}

// destroy releases everything the router owns: the route and middleware lists,
// the MIME table, and the per-mount contexts allocated by `static` (the context
// struct plus its `root` and `base_path` strings).
//
// Paths registered directly through `add`/`get`/`post` are not copied by the
// router and are therefore not freed here.
// NOTE(review): the static contexts are freed with the current context.allocator;
// this assumes it is the allocator that was active when `static` was called.
destroy :: proc(r: ^Router) {
	for rt in r.routes {
		// `static` is the only place that registers `static_handler`, and it always
		// passes a heap-allocated ^Static_Context as the route's user data.
		if rt.handler == static_handler && rt.user_data != nil {
			static_ctx := (^Static_Context)(rt.user_data)
			delete(static_ctx.root)
			delete(static_ctx.base_path)
			// `mime_map` is a copy of r.mime_map's header and is released below, once.
			free(static_ctx)
		}
	}
	delete(r.mime_map)
	delete(r.middlewares)
	delete(r.routes)
}

// add registers `h` for requests with `method` whose target matches `path`
// according to `kind` (see Route_Kind). `user_data` is handed back to `h` as `ctx`.
//
// `path` is stored as-is, not copied: it must stay valid while the router is in use.
// When several routes of the same kind match, earlier registrations win ties.
add :: proc(
	r: ^Router,
	method: http.Method,
	path: string,
	h: Route_Handler,
	kind: Route_Kind = .Exact,
	// Defaults to nil like the `get`/`post`/`use` helpers, so handlers without
	// state don't need to spell it out.
	user_data: rawptr = nil,
) {
	append(&r.routes, Route{method = method, kind = kind, path = path, handler = h, user_data = user_data})
}

// get registers `h` for GET requests on `path`; shorthand for `add` with method .Get.
get :: proc(r: ^Router, path: string, h: Route_Handler, kind: Route_Kind = .Exact, user_data: rawptr = nil) {
	add(r = r, method = http.Method.Get, path = path, h = h, kind = kind, user_data = user_data)
}

// post registers `h` for POST requests on `path`; shorthand for `add` with method .Post.
post :: proc(r: ^Router, path: string, h: Route_Handler, kind: Route_Kind = .Exact, user_data: rawptr = nil) {
	add(r = r, method = http.Method.Post, path = path, h = h, kind = kind, user_data = user_data)
}

// use appends a middleware. Every matched route is run through all registered
// middlewares; `user_data` is handed back to `h` as its `ctx` argument.
use :: proc(r: ^Router, h: Middleware_Handler, user_data: rawptr = nil) {
	entry := Middleware{handler = h, user_data = user_data}
	append(&r.middlewares, entry)
}

// Static_Context is the per-mount state that `static` allocates and registers as
// the route's user data; `static_handler` reads it on every request.
@(private)
Static_Context :: struct {
	root:      string,            // absolute directory files are served from
	base_path: string,            // URL prefix the directory is mounted at (owned clone)
	// NOTE(review): this is a copy of the router's map value taken inside `static`.
	// Entries added to Router.mime_map afterwards may not be visible, and a rehash
	// of the router's map could leave this copy dangling. Confirm mime_map is never
	// mutated after `static` is called.
	mime_map:  map[string]string,
}

// static mounts the directory `filepath` at the URL prefix `path`. Files below it
// are served for GET requests through `static_handler` (a .Prefix route).
//
// `filepath` is resolved to an absolute directory once, here. If that fails, an
// error is printed and no route is added. Without this check the empty root would
// let the handler's containment check accept any file.
static :: proc(r: ^Router, path, filepath: string) {
	root, err := os.get_absolute_path(filepath, context.allocator)
	if err != nil {
		fmt.eprintf("router: couldn't resolve static directory %q: %v\n", filepath, err)
		return
	}

	ctx := new(Static_Context) // TODO: cleanup of memory
	ctx.root = root
	ctx.base_path = strings.clone(path) // TODO: cleanup of memory
	ctx.mime_map = r.mime_map

	// Register the owned clone so the route does not depend on the caller keeping `path` alive.
	add(r, .Get, ctx.base_path, static_handler, .Prefix, ctx)
}

// handle adapts the router to the http package's handler interface. It returns a
// proc/context pair: the proc wraps each http.Request in a router Request and
// forwards it to router_handler, with the router pointer as the context.
handle :: proc(r: ^Router) -> (http.Handler, rawptr) {
	adapter :: proc(req: http.Request, res: ^http.Response, ctx: rawptr) {
		wrapped := Request{_http_req = req}
		router_handler(wrapped, res, ctx)
	}
	return adapter, r
}

// verify_and_get_absolute_filepath resolves `target` against the directory `root`.
// It returns the absolute path of the result only if that path is a regular file
// located inside `root`.
//
// `root` is expected to be an absolute directory path; an empty `root` is rejected.
// On success the returned string is allocated with context.allocator and owned by
// the caller. On failure it returns ("", false) and nothing needs freeing.
verify_and_get_absolute_filepath :: proc(root, target: string) -> (string, bool) {
	// An empty root would make every path look like it is "inside" it.
	if len(root) == 0 {
		return "", false
	}

	raw_file_path, join_err := os.join_path({root, target}, context.allocator)
	if join_err != nil {
		return "", false
	}
	defer delete(raw_file_path)

	requested_file, abs_err := os.get_absolute_path(raw_file_path, context.allocator)
	if abs_err != nil {
		return "", false
	}

	// A bare prefix test would also accept siblings such as "/srv/www-private" for
	// root "/srv/www", so the match must end exactly on a path separator.
	inside := strings.has_prefix(requested_file, root)
	if inside && len(requested_file) > len(root) && !os.is_path_separator(root[len(root) - 1]) {
		inside = os.is_path_separator(requested_file[len(root)])
	}

	if !inside || !os.is_file(requested_file) {
		// The caller only owns the path on success, so release it here.
		delete(requested_file)
		return "", false
	}

	return requested_file, true
}

// static_handler serves files for a directory mounted with `static`.
//
// The part of the request target after the mount point (`base_path`) is resolved
// against `root`. Anything that does not resolve to a regular file inside `root`
// gets a 404. The content type is looked up by file extension in `mime_map`,
// falling back to application/octet-stream.
// `ctx` must be the ^Static_Context that `static` registered with this route.
@(private)
static_handler :: proc(req: Request, res: ^Response, ctx: rawptr) {
	static_context := (^Static_Context)(ctx)
	base := static_context.base_path

	// The router only dispatches here on a prefix match; re-check so slicing is safe.
	if !strings.has_prefix(req.target, base) {
		http.respond(res, .NotFound, nil)
		return
	}
	target := req.target[len(base):]

	// A mount point without a trailing slash ("/static") must be followed by one.
	// Otherwise "/staticfoo" would be served from the same directory as "/static/foo".
	// A mount point with a trailing slash ("/static/") already consumed it.
	if !strings.has_suffix(base, "/") {
		if !strings.has_prefix(target, "/") {
			http.respond(res, .NotFound, nil)
			return
		}
		target = target[1:]
	}
	if len(target) == 0 {
		http.respond(res, .NotFound, nil)
		return
	}

	requested_file, ok := verify_and_get_absolute_filepath(static_context.root, target)
	if !ok {
		http.respond(res, .NotFound, nil)
		return
	}
	defer delete(requested_file)

	// TODO: don't read complete file into memory, stream/chunk it
	buf, err := os.read_entire_file_from_path(requested_file, context.allocator)
	if err != nil {
		fmt.eprintf("couldn't read file: %v\n", err)
		http.respond(res, .InternalServerError, nil)
		return
	}
	defer delete(buf)

	// Headers are only added once the file was read, so an error response does not
	// carry the file's content type or a public cache-control directive.
	mime, known := static_context.mime_map[os.ext(requested_file)]
	if !known {
		mime = "application/octet-stream"
	}
	http.headers_add(&res.headers, "content-type", mime)
	http.headers_add(&res.headers, "cache-control", "public, max-age=3600")

	http.respond(res, .Ok, buf)
}

// chain_node_handler is the Route_Handler for the inner links of a middleware
// chain. It runs the node's middleware and hands it the next link to continue with.
// It does not free anything: the outermost link (chain_root_handler) frees the
// whole chain after the request has been handled.
@(private)
chain_node_handler :: proc(req: Request, res: ^Response, ctx: rawptr) {
	n := (^Chain_Node)(ctx)
	n.middleware(req, res, n.middleware_ctx, n.next, n.next_ctx)
}

// chain_root_handler runs the outermost middleware and then frees every node of
// the chain. Freeing from the root, rather than letting each node free itself when
// it runs, has two effects:
//   - nodes behind a middleware that returns without calling `next` are not leaked;
//   - a middleware may call `next` more than once without touching freed memory.
@(private)
chain_root_handler :: proc(req: Request, res: ^Response, ctx: rawptr) {
	root := (^Chain_Node)(ctx)
	root.middleware(req, res, root.middleware_ctx, root.next, root.next_ctx)

	node := root
	for node != nil {
		// Only links built by build_chain continue into chain_node_handler. The
		// innermost `next` is the route handler, whose context belongs to the route.
		following: ^Chain_Node
		if node.next == chain_node_handler {
			following = (^Chain_Node)(node.next_ctx)
		}
		free(node)
		node = following
	}
}

// build_chain wraps `final_handler` in `middlewares` and returns the handler and
// context to invoke for one request.
//
// With no middlewares the route handler is returned unchanged and nothing is
// allocated. Otherwise one Chain_Node per middleware is allocated with
// context.allocator; the returned root handler frees them all when it finishes.
// NOTE(review): each middleware wraps the ones before it, so the middleware
// registered last with `use` runs first. Confirm this ordering is intended.
@(private)
build_chain :: proc(
	final_handler: Route_Handler,
	final_ctx: rawptr,
	middlewares: []Middleware,
) -> (
	Route_Handler,
	rawptr,
) {
	if len(middlewares) == 0 {
		return final_handler, final_ctx
	}

	curr_handler := final_handler
	curr_ctx := final_ctx
	for m in middlewares {
		node := new(Chain_Node)
		node.middleware = m.handler
		node.middleware_ctx = m.user_data
		node.next = curr_handler
		node.next_ctx = curr_ctx
		curr_handler = chain_node_handler
		curr_ctx = node
	}

	// The outermost node owns the whole chain and frees it after running.
	return chain_root_handler, curr_ctx
}

// router_handler is the handler installed by `handle`; `user_data` is the ^Router.
//
// Routes are matched in three passes, and the first pass that finds a route wins:
//   1. .Exact   - method and path must equal the request target (first registered wins).
//   2. .Prefix  - the longest path that prefixes the target wins; the unmatched tail
//                 is exposed as `req.rest_path`.
//   3. .Pattern - path and target are compared segment by segment. ":name" captures
//                 one segment, and a trailing "*name" captures the rest of the target,
//                 slashes included. The pattern with the most segments wins, and ties
//                 go to the route registered first.
// The matched handler runs inside the router's middleware chain. When nothing
// matches, the response is a 404.
@(private)
router_handler :: proc(req: Request, res: ^Response, user_data: rawptr) {
	r := (^Router)(user_data)

	req := req
	req.params = make(map[string]string)
	// Keys and values point into route paths and the request target; only the map
	// itself is owned here, so handlers must not keep it past their return.
	defer delete(req.params)

	// Pass 1: exact match.
	for rt in r.routes {
		if rt.kind == .Exact && rt.method == req.method && rt.path == req.target {
			router_dispatch(r, rt, req, res)
			return
		}
	}

	// Pass 2: longest prefix match.
	best_prefix := -1
	best_len := -1
	for rt, i in r.routes {
		if rt.kind != .Prefix {
			continue
		}
		if rt.method != req.method {
			// TODO: Support 405 in the future
			continue
		}
		if len(rt.path) > best_len && strings.has_prefix(req.target, rt.path) {
			best_prefix = i
			best_len = len(rt.path)
		}
	}
	if best_prefix >= 0 {
		req.rest_path = req.target[best_len:]
		router_dispatch(r, r.routes[best_prefix], req, res)
		return
	}

	// Pass 3: segment patterns. Only targets starting with '/' are split; anything
	// else (e.g. an empty target) cannot match a pattern.
	if strings.has_prefix(req.target, "/") {
		// TODO: rewrite logic to parse over windows of values instead of splitting
		target_parts := strings.split(req.target[1:], "/")
		defer delete(target_parts)

		best_pattern := -1
		best_segments := 0
		for rt, i in r.routes {
			if rt.kind != .Pattern {
				continue
			}
			if rt.method != req.method {
				// TODO: Support 405 in the future
				continue
			}

			// TODO: Do all this work at route setup
			pattern_parts := strings.split(strings.trim_prefix(rt.path, "/"), "/")
			matched := router_pattern_matches(rt.path, pattern_parts, target_parts)
			segments := len(pattern_parts)
			delete(pattern_parts)

			if matched && segments > best_segments {
				best_pattern = i
				best_segments = segments
			}
		}

		if best_pattern >= 0 {
			rt := r.routes[best_pattern]
			pattern_parts := strings.split(strings.trim_prefix(rt.path, "/"), "/")
			defer delete(pattern_parts)

			router_capture_params(&req.params, req.target, pattern_parts, target_parts)
			router_dispatch(r, rt, req, res)
			return
		}
	}

	// TODO: Have a 404 handler
	http.respond(res, .NotFound, nil)
}

// router_dispatch runs the route `rt` inside the router's middleware chain.
@(private)
router_dispatch :: proc(r: ^Router, rt: Route, req: Request, res: ^Response) {
	chain_handler, chain_ctx := build_chain(rt.handler, rt.user_data, r.middlewares[:])
	chain_handler(req, res, chain_ctx)
}

// router_pattern_matches reports whether a route pattern, already split on '/',
// matches the request target split the same way.
//   - ":name" matches any single segment.
//   - "*name" is only valid as the last segment. It matches the rest of the target,
//     which must contain at least one (possibly empty) segment.
//   - Every other segment, including an empty one, must match literally.
// `path` is only used in the diagnostic printed for a misplaced wildcard.
@(private)
router_pattern_matches :: proc(path: string, pattern_parts, target_parts: []string) -> bool {
	last := len(pattern_parts) - 1
	for p, j in pattern_parts {
		if j != last && strings.has_prefix(p, "*") {
			fmt.eprintf("router: wildcard must be last parameter -> %s\n", path)
			return false
		}
	}

	// Without a trailing wildcard the segment counts must agree exactly, otherwise
	// "/users/:id" would also match "/users/5/posts".
	has_wildcard := last >= 0 && strings.has_prefix(pattern_parts[last], "*")
	if has_wildcard {
		if len(target_parts) < len(pattern_parts) {
			return false
		}
	} else if len(target_parts) != len(pattern_parts) {
		return false
	}

	for p, j in pattern_parts {
		if strings.has_prefix(p, ":") || strings.has_prefix(p, "*") {
			continue
		}
		if target_parts[j] != p {
			return false
		}
	}
	return true
}

// router_capture_params stores the values captured by ":name" and "*name" segments
// in `params`. Two preconditions:
//   - the pair must already be accepted by router_pattern_matches;
//   - `target` must start with '/' (target_parts is target[1:] split on '/').
@(private)
router_capture_params :: proc(params: ^map[string]string, target: string, pattern_parts, target_parts: []string) {
	// Byte offset of segment j inside `target`; starts at 1 to skip the leading '/'.
	offset := 1
	for p, j in pattern_parts {
		if strings.has_prefix(p, ":") {
			params^[p[1:]] = target_parts[j]
		} else if strings.has_prefix(p, "*") {
			// The wildcard is always last, so it takes everything from here on.
			params^[p[1:]] = target[offset:]
		}
		offset += len(target_parts[j]) + 1
	}
}