🍯 Glaze

#+private
package http

import "core:bytes"
import "core:fmt"
import "core:mem"
import "core:net"
import "core:strconv"
import "core:sync"
import "core:sys/posix"
import "core:testing"
import "core:time"

import "../../container/ring_buffer"

// Quit flag raised by should_quit_handler. It is written from signal context,
// so readers should load it with sync.atomic_load.
SHOULD_QUIT := false

// should_quit_handler has the C calling convention and takes a posix.Signal,
// i.e. it is meant to be installed as a POSIX signal handler. It only raises
// SHOULD_QUIT; the store is atomic so that the write is well-defined even when
// the signal interrupts another thread that is reading the flag.
should_quit_handler :: proc "c" (sig: posix.Signal) {
	sync.atomic_store(&SHOULD_QUIT, true)
}

// find_crlf returns the index of the first "\r\n", or -1 if there is none.
// For a byte slice the index is into the slice; for a ring buffer it is the
// logical index counted from the buffer's read offset.
find_crlf :: proc {
	find_crlf_bytes,
	find_crlf_ring_buffer,
}

// find_double_crlf returns the index of the first "\r\n\r\n" (the end of an
// HTTP head), or -1 if there is none. For a ring buffer the index is logical,
// counted from the buffer's read offset.
find_double_crlf :: proc {
	find_double_crlf_bytes,
	find_double_crlf_ring_buffer,
}

@(test)
test_ca_find_double_crlf_at_start :: proc(t: ^testing.T) {
	rb: ring_buffer.Ring_Buffer
	ring_buffer.init(&rb)
	defer ring_buffer.destroy(&rb)

	// The read offset is left as init set it, so the data does not wrap.
	// "Foobar" is 6 bytes, so the terminator must be reported at exactly
	// index 6; asserting the exact value (not just i >= 0) also catches
	// off-by-N results.
	ring_buffer.append(&rb, {'F', 'o', 'o', 'b', 'a', 'r', '\r', '\n', '\r', '\n'})
	testing.expect_value(t, find_double_crlf_ring_buffer(&rb), 6)
}

@(test)
test_ca_find_double_crlf_over_border :: proc(t: ^testing.T) {
	rb: ring_buffer.Ring_Buffer
	ring_buffer.init(&rb)
	defer ring_buffer.destroy(&rb)

	// Start 7 slots before the end of the storage: "Foobar\r" fills the last
	// 7 physical slots and "\n\r\n" wraps to slots 0..2. The terminator itself
	// (and its first CRLF) therefore straddles the wrap point. With an offset
	// of cap - 5 only "Foobar" wrapped and the match never crossed the border.
	// NOTE(review): assumes the default capacity is >= 10 so append does not
	// reallocate; the exact-index assertion below holds either way.
	rb.off += rb.cap - 7
	ring_buffer.append(&rb, {'F', 'o', 'o', 'b', 'a', 'r', '\r', '\n', '\r', '\n'})
	testing.expect_value(t, find_double_crlf_ring_buffer(&rb), 6)
}

// Conn is the per-connection state used by conn_handle.
Conn :: struct {
	// Client socket; handle_conn_job closes it once conn_handle returns.
	sock:  net.TCP_Socket,
	// Bytes received from sock that have not yet been consumed as a request
	// head or body (may contain the start of the next pipelined request).
	carry: ring_buffer.Ring_Buffer,
}

// send_response_header serializes and sends the status line and header section
// of res. Before sending it adds defaults when the handler did not set them:
// "content-type: text/html; charset=utf-8" and a content-length equal to
// len(res.body). It always appends a Connection header ("close" when
// should_close, "keep-alive" otherwise) and the terminating empty line.
//
// Ownership: res.headers (every name, every value and the array itself) is
// freed before returning, on every path. res.body is only measured here; it
// is neither sent nor freed.
//
// Returns false if writing to sock failed.
send_response_header :: proc(sock: net.TCP_Socket, res: ^Response, should_close: bool) -> bool {
	defer {
		for h in res.headers {
			delete(h.name)
			delete(h.value)
		}
		delete(res.headers)
	}

	if _, ok := headers_get_first(res.headers[:], "content-type"); !ok {
		headers_add(&res.headers, "content-type", "text/html; charset=utf-8")
	}
	if _, ok := headers_get_first(res.headers[:], "content-length"); !ok {
		// 20 bytes hold any i64 in base 10, sign included. The previous
		// 10-byte buffer was too small for lengths of 10+ digits.
		digits: [20]u8
		// len() of a nil slice is 0, so no separate nil check is needed.
		headers_add(&res.headers, "content-length", strconv.write_int(digits[:], i64(len(res.body)), 10))
	}

	// Local copy of the Status_Message table, indexed by res.code below.
	status_messages := Status_Message

	head: bytes.Buffer
	defer bytes.buffer_destroy(&head)

	bytes.buffer_write_string(&head, "HTTP/1.1 ")
	bytes.buffer_write_string(&head, status_messages[res.code])
	bytes.buffer_write_string(&head, "\r\n")

	for h in res.headers {
		bytes.buffer_write_string(&head, h.name)
		bytes.buffer_write_string(&head, ": ")
		bytes.buffer_write_string(&head, h.value)
		bytes.buffer_write_string(&head, "\r\n")
	}

	bytes.buffer_write_string(&head, "Connection: ")
	bytes.buffer_write_string(&head, "close" if should_close else "keep-alive")
	bytes.buffer_write_string(&head, "\r\n")

	// Empty line terminating the header section.
	bytes.buffer_write_string(&head, "\r\n")

	return send_response(sock, bytes.buffer_to_bytes(&head))
}

// send_response writes a response payload to a socket. Only the byte-slice
// overload (send_response_bytes) is currently enabled; a streaming variant is
// stubbed out below.
send_response :: proc {
	send_response_bytes,
//send_response_stream,
}

// TODO: make this overridable so users can render their own error responses.
//
// send_response_error sends a header-only response with the given status code
// and "Connection: close". A send failure is deliberately ignored: every caller
// in this file returns (and so ends the connection) right after calling it.
send_response_error :: proc(sock: net.TCP_Socket, code: Status_Code) {
	res: Response
	res.code = code
	_ = send_response_header(sock, &res, true)
}

// send_response_bytes writes all of buf to sock, retrying with the unsent
// remainder until every byte is sent. Returns false on the first socket error.
send_response_bytes :: proc(sock: net.TCP_Socket, buf: []u8) -> bool {
	pending := buf
	for len(pending) > 0 {
		written, send_err := net.send_tcp(sock, pending)
		if send_err != nil {
			return false
		}
		pending = pending[written:]
	}
	return true
}

/*send_response_body_stream :: proc(...) {
}*/

// Size of the stack scratch buffer used for one recv call.
READ_CHUNK :: 4096

// read_into_carry performs a single recv (at most READ_CHUNK bytes) on c.sock
// and appends whatever arrived to c.carry.
//
// Returns ok = false when recv reports an error or 0 bytes (peer closed).
// timeout is true only when the error is .Would_Block, which is what a
// receive timeout surfaces as here.
read_into_carry :: proc(c: ^Conn) -> (ok: bool, timeout: bool) {
	scratch: [READ_CHUNK]u8
	got, recv_err := net.recv_tcp(c.sock, scratch[:])
	if recv_err != nil {
		return false, recv_err == .Would_Block
	}
	if got == 0 {
		// Orderly shutdown by the peer: not a timeout.
		return false, false
	}
	ring_buffer.append(&c.carry, scratch[:got])
	return true, false
}

// find_crlf_bytes returns the index of the first "\r\n" in b, or -1 if b
// contains none (including when b is shorter than 2 bytes).
//
// Every offset is examined. Stepping by 2 would miss a CRLF that starts at an
// odd index, e.g. the end of the request line "GET /a HTTP/1.1".
find_crlf_bytes :: proc(b: []u8) -> int {
	for i := 0; i + 1 < len(b); i += 1 {
		if b[i] == '\r' && b[i + 1] == '\n' {
			return i
		}
	}
	return -1
}

// find_crlf_ring_buffer returns the logical index (counted from ca.off) of
// the first "\r\n" held in the ring buffer, or -1 if there is none. Physical
// slots are taken modulo ca.cap, so a CRLF split across the end of the
// storage is still found.
find_crlf_ring_buffer :: proc(ca: ^ring_buffer.Ring_Buffer) -> int {
	if ca.len >= 2 {
		for start in 0 ..< ca.len - 1 {
			p := ca.off + start
			if ca.buf[p % ca.cap] == '\r' && ca.buf[(p + 1) % ca.cap] == '\n' {
				return start
			}
		}
	}
	return -1
}

// find_double_crlf_bytes returns the index of the first "\r\n\r\n" in b, or -1
// if b contains none (including when b is shorter than 4 bytes).
//
// Every offset is examined. Stepping by 4 would only find a terminator that
// happens to start at a multiple of 4.
find_double_crlf_bytes :: proc(b: []u8) -> int {
	for i := 0; i + 3 < len(b); i += 1 {
		if b[i] == '\r' && b[i + 1] == '\n' && b[i + 2] == '\r' && b[i + 3] == '\n' {
			return i
		}
	}
	return -1
}

// find_double_crlf_ring_buffer returns the logical index (counted from ca.off)
// of the first "\r\n\r\n" held in the ring buffer, or -1 if there is no
// complete sequence. Physical slots are taken modulo ca.cap, so a terminator
// that wraps past the end of the storage is still found.
find_double_crlf_ring_buffer :: proc(ca: ^ring_buffer.Ring_Buffer) -> int {
	if ca.len >= 4 {
		for start in 0 ..< ca.len - 3 {
			p := ca.off + start
			if ca.buf[p % ca.cap] == '\r' &&
			   ca.buf[(p + 1) % ca.cap] == '\n' &&
			   ca.buf[(p + 2) % ca.cap] == '\r' &&
			   ca.buf[(p + 3) % ca.cap] == '\n' {
				return start
			}
		}
	}
	return -1
}

// Upper bound (32 KiB) on bytes buffered by read_head_from_conn while looking
// for the end of a request head; beyond it the request is rejected.
MAX_HEADER_BYTES :: 32 * 1024
// Largest request body (10 MiB) that read_body_from_conn will accept.
MAX_BODY_BYTES :: 10 * 1024 * 1024

// read_head_from_conn reads from c until c.carry holds a complete request
// head, then removes and returns it. The returned head includes its
// terminating "\r\n\r\n" and comes from ring_buffer.consume_front; callers in
// this file delete() it.
//
// Returns ok = false (timeout = false) if the head would exceed
// MAX_HEADER_BYTES. Returns ok = false with timeout from read_into_carry when
// the read fails.
read_head_from_conn :: proc(c: ^Conn) -> (head: []u8, ok: bool, timeout: bool) {
	for {
		// Search before enforcing the limit. The carry may already hold a
		// complete head followed by further (e.g. pipelined) bytes; that must
		// not be rejected just because the whole buffer is large. A terminator
		// at index 0 is accepted too, so such input fails fast in parsing
		// instead of stalling until timeout.
		if i := find_double_crlf(&c.carry); i >= 0 {
			head_len := i + 4
			if head_len > MAX_HEADER_BYTES {
				return nil, false, false
			}
			return ring_buffer.consume_front(&c.carry, head_len), true, false
		}

		// No terminator yet: stop buffering once the limit is exceeded.
		if c.carry.len > MAX_HEADER_BYTES {
			return nil, false, false
		}

		if ok, timeout = read_into_carry(c); !ok {
			return nil, false, timeout
		}
	}
}

// read_body_from_conn reads from c until at least want bytes are buffered,
// then removes exactly want bytes from the front of c.carry and returns them.
// The result comes from ring_buffer.consume_front; callers in this file
// delete() it.
//
// want == 0 yields (nil, true, false). A negative want or one above
// MAX_BODY_BYTES yields ok = false, timeout = false. A failed read yields
// ok = false with timeout from read_into_carry.
read_body_from_conn :: proc(c: ^Conn, want: int) -> (body: []u8, ok: bool, timeout: bool) {
	if want < 0 || want > MAX_BODY_BYTES {
		return nil, false, false
	}
	if want == 0 {
		return nil, true, false
	}

	// Memory stays bounded without a second size check. Each read appends at
	// most READ_CHUNK bytes and the loop stops as soon as want bytes are
	// buffered. Checking carry.len against MAX_BODY_BYTES after a read
	// wrongly rejected valid bodies near the limit when that read overshot.
	for c.carry.len < want {
		if ok, timeout = read_into_carry(c); !ok {
			return nil, false, timeout
		}
	}

	body = ring_buffer.consume_front(&c.carry, want)
	return body, true, false
}

// parse_request parses req._raw_head (a head ending in "\r\n\r\n", as
// produced by read_head_from_conn) into req.method/target/version,
// req.headers, req.host and req.content_length.
//
// Returns false for a malformed or unsupported request line, a missing Host
// header on an HTTP/1.1 request, or a Content-Length that is not a plain
// decimal number.
parse_request :: proc(req: ^Request) -> bool {
	// Parses a Content-Length value strictly as 1*DIGIT, optionally surrounded
	// by SP/HTAB. strconv.parse_int would also accept signs and base prefixes
	// such as "0x10", which have no place in this header.
	parse_decimal_length :: proc(s: string) -> (n: int, ok: bool) {
		v := s
		for len(v) > 0 && (v[0] == ' ' || v[0] == '\t') {
			v = v[1:]
		}
		for len(v) > 0 && (v[len(v) - 1] == ' ' || v[len(v) - 1] == '\t') {
			v = v[:len(v) - 1]
		}
		// 18 decimal digits always fit in an i64, so accumulation cannot overflow.
		if len(v) == 0 || len(v) > 18 {
			return 0, false
		}
		acc: i64
		for i in 0 ..< len(v) {
			d := v[i]
			if d < '0' || d > '9' {
				return 0, false
			}
			acc = acc * 10 + i64(d - '0')
		}
		if acc > i64(max(int)) {
			return 0, false
		}
		return int(acc), true
	}

	raw := req._raw_head
	// Anything shorter than the terminator cannot be a head (also covers nil).
	if len(raw) < 4 {
		return false
	}

	// The request line ends at the first CRLF. Searching the whole head,
	// terminator included, also handles a request with no header fields. There
	// the first CRLF begins the final "\r\n\r\n"; the trimmed block used before
	// contained no CRLF, so such requests were rejected.
	head_i := find_crlf(raw)
	if head_i <= 0 {
		return false
	}

	// An unknown method or malformed request line must fail the request.
	if !read_head_line(req, raw[:head_i]) {
		return false
	}
	read_headers(req, raw[head_i + 2:])

	host_header, h_ok := headers_get_first(req.headers[:], "host")
	if !h_ok {
		// Host is mandatory in HTTP/1.1.
		if req.version == "HTTP/1.1" {
			return false
		}
	} else {
		req.host = host_header
	}

	if cl_header, cl_ok := headers_get_first(req.headers[:], "content-length"); cl_ok {
		length, ok := parse_decimal_length(cl_header)
		if !ok {
			return false
		}
		req.content_length = length
	}

	return true
}

// read_head_line parses a request line of the form
// "METHOD SP request-target SP HTTP-version" into req.method, req.target and
// req.version. target and version are string views into head_line, not copies.
//
// Returns false when the method or target is empty or has no following space,
// or when the method is not one of GET/POST/PUT/DELETE/PATCH/HEAD/OPTIONS.
// The last case is also reported on stderr. On any failure req.target and
// req.version are left untouched.
read_head_line :: proc(req: ^Request, head_line: []u8) -> bool {
	first_sp := bytes.index_byte(head_line, ' ')
	if first_sp <= 0 {
		return false
	}
	method := string(head_line[:first_sp])
	tail := head_line[first_sp + 1:]

	second_sp := bytes.index_byte(tail, ' ')
	if second_sp <= 0 {
		return false
	}
	target := string(tail[:second_sp])
	version := string(tail[second_sp + 1:])

	known := true
	switch method {
	case "GET":
		req.method = .Get
	case "POST":
		req.method = .Post
	case "PUT":
		req.method = .Put
	case "DELETE":
		req.method = .Delete
	case "PATCH":
		req.method = .Patch
	case "HEAD":
		req.method = .Head
	case "OPTIONS":
		req.method = .Options
	case:
		known = false
	}
	if !known {
		fmt.eprintf("unknown method: %s\n", method)
		return false
	}

	req.target = target
	req.version = version
	return true
}

// read_headers parses "name: value" lines from header_block into req.headers
// via headers_add. Lines are separated by "\r\n". Parsing stops at the first
// empty line (the end of the header section) or at the end of the block; a
// final line without CRLF is still parsed.
//
// Lines without a colon, or with an empty name, are skipped. Leading and
// trailing optional whitespace (SP/HTAB) is stripped from each value.
//
// The previous loop recomputed the line end from the old offset before
// advancing it, so it exited after the first header and silently dropped all
// others (Content-Length, Connection, ...).
read_headers :: proc(req: ^Request, header_block: []u8) {
	// Removes leading/trailing SP and HTAB.
	trim_ows :: proc(s: []u8) -> []u8 {
		v := s
		for len(v) > 0 && (v[0] == ' ' || v[0] == '\t') {
			v = v[1:]
		}
		for len(v) > 0 && (v[len(v) - 1] == ' ' || v[len(v) - 1] == '\t') {
			v = v[:len(v) - 1]
		}
		return v
	}

	rest := header_block
	for len(rest) > 0 {
		end := find_crlf(rest)
		line := rest if end < 0 else rest[:end]
		if len(line) == 0 {
			// Empty line: end of the header section.
			break
		}
		if colon := bytes.index_byte(line, ':'); colon > 0 {
			headers_add(&req.headers, string(line[:colon]), string(trim_ows(line[colon + 1:])))
		}
		if end < 0 {
			break
		}
		rest = rest[end + 2:]
	}
}

// should_close_connection decides whether the connection ends after this
// request. An explicit "Connection: close" always wins. Otherwise HTTP/1.0
// closes unless "keep-alive" was requested, and any other version stays open.
should_close_connection :: proc(req: ^Request) -> bool {
	if header_has_token(req.headers[:], "connection", "close") {
		return true
	}
	if req.version != "HTTP/1.0" {
		return false
	}
	return !header_has_token(req.headers[:], "connection", "keep-alive")
}

// Maximum number of requests conn_handle serves on one connection before it
// returns; handle_conn_job then closes the socket.
MAX_REQUESTS_PER_CONN :: 100

// Conn_Job bundles what handle_conn_job needs to serve one accepted
// connection. handle_conn_job frees the job with the allocator that is
// current on entry. NOTE(review): presumably the one that allocated it;
// confirm at the allocation site.
Conn_Job :: struct {
	conn:               Conn,
	// Called once per parsed request by conn_handle.
	handler:            Handler,
	// Opaque pointer passed through unchanged to handler.
	user_data:          rawptr,
	// Shared connection counter; handle_conn_job atomically decrements it
	// when the connection is finished.
	active_connections: ^int,
}

// NOTE: Takes ownership of Conn_Job
//
// handle_conn_job serves one connection with conn_handle. It then decrements
// the shared connection counter, closes the socket and frees cj with the
// allocator that was current on entry. In debug builds the connection runs
// under a tracking allocator, and leaks and bad frees are printed at the end.
handle_conn_job :: proc(cj: ^Conn_Job) {
	main_allocator := context.allocator

	when ODIN_DEBUG {
		track: mem.Tracking_Allocator
		mem.tracking_allocator_init(&track, context.allocator)
		context.allocator = mem.tracking_allocator(&track)
	}

	// Keep the counter pointer locally: cj is freed below, and the debug
	// message after that previously dereferenced the freed job.
	active_connections := cj.active_connections

	when ODIN_DEBUG {
		// The counter is shared between connection threads, so read it atomically.
		fmt.printf("debug: ### enter conn_handle (n: %d/%d)\n", sync.atomic_load(active_connections), MAX_CONNECTIONS)
	}

	conn_handle(&cj.conn, cj.handler, cj.user_data)

	sync.atomic_sub(active_connections, 1)
	net.close(cj.conn.sock)
	free(cj, main_allocator)

	when ODIN_DEBUG {
		fmt.printf("debug: ### leave conn_handle (n: %d/%d)\n", sync.atomic_load(active_connections), MAX_CONNECTIONS)

		for _, leak in track.allocation_map {
			fmt.printf("%v leaked %m\n", leak.location, leak.size)
		}
		for bad in track.bad_free_array {
			fmt.printf("%v bad free\n", bad.location)
		}
		mem.tracking_allocator_destroy(&track)
	}
}

// Receive timeout (15 s) that conn_handle sets on the client socket via
// net.set_option(.Receive_Timeout); a recv that waits longer fails.
MAX_IDLE_TIMEOUT :: 15 * time.Second

// conn_handle serves up to MAX_REQUESTS_PER_CONN requests on c.sock. For each
// request it:
//   - reads and parses the head,
//   - reads a Content-Length body,
//   - answers Transfer-Encoding requests with 501,
//   - otherwise invokes req_handler and writes the response.
//
// It returns when the request asks to close, on a read timeout (after sending
// 408), on a malformed request (after sending 400), or on any I/O error.
// Closing the socket is left to the caller.
//
// All per-request allocations (raw head, header copies, request body,
// response body) are released by defers at the end of each iteration, so
// every early return frees them as well.
conn_handle :: proc(c: ^Conn, req_handler: Handler, user_data: rawptr) {
	net.set_option(c.sock, .Receive_Timeout, MAX_IDLE_TIMEOUT)

	ring_buffer.init(&c.carry, 10)
	defer ring_buffer.destroy(&c.carry)

	for req_count := 0; req_count < MAX_REQUESTS_PER_CONN; req_count += 1 {
		// Registered first so it runs last in each iteration, including on
		// early returns (previously only after a fully successful request).
		defer free_all(context.temp_allocator)

		raw_head_bytes, ok, timeout := read_head_from_conn(c)
		if !ok {
			if timeout {
				send_response_error(c.sock, .RequestTimeout)
			}
			return
		}

		req: Request
		req._raw_head = raw_head_bytes
		req.headers = make([dynamic]Header, 0, 4)
		// Single cleanup for this request on every path. req.host is not freed
		// separately: it comes from headers_get_first, and the success path
		// never treated it as owned. The old parse-failure path freed it and
		// then freed the same Host header value again (double free).
		defer {
			for h in req.headers {
				delete(h.name)
				delete(h.value)
			}
			delete(req.headers)
			delete(req._raw_head)
			delete(req.body)
		}

		if !parse_request(&req) {
			send_response_error(c.sock, .BadRequest)
			return
		}

		should_close := should_close_connection(&req)

		if req.content_length > 0 {
			req.body, ok, timeout = read_body_from_conn(c, req.content_length)
			if !ok {
				if timeout {
					send_response_error(c.sock, .RequestTimeout)
				}
				return
			}
		}

		if _, te_ok := headers_get_first(req.headers[:], "Transfer-Encoding"); te_ok {
			send_response_error(c.sock, .NotImplemented)
			return
		}

		// --- HANDLER CODE ---

		when ODIN_DEBUG {
			fmt.printf("debug: --> %s %s %s\n", req.method, req.target, req.version)
			fmt.printf("debug: headers -> %v\n", req.headers)
			// fmt.printf("debug: body -> %s\n", req.body)
		}

		res: Response
		res.headers = make([dynamic]Header, 0, 4)
		req_handler(req, &res, user_data)
		// send_response_header frees res.headers. The body is freed here on
		// every path; previously it leaked whenever a send failed.
		defer delete(res.body)

		when ODIN_DEBUG {
			statusMessage := Status_Message
			fmt.printf("debug: <-- %s %s\n", req.target, statusMessage[res.code])
		}

		if !send_response_header(c.sock, &res, should_close) {
			fmt.eprintf("send header: error\n")
			return
		}
		if res.body != nil && !send_response(c.sock, res.body) {
			fmt.eprintf("send body: error\n")
			return
		}

		if should_close {
			break
		}
	}

	when ODIN_DEBUG {
		fmt.printf("debug: #!#! connection closed !#!#\n")
	}
}