package day_08

import "core:log"
import "core:os"
import "core:sort"
import "core:strconv"
import "core:strings"

// Parsed puzzle input: one [3]int per input line, holding that line's three
// comma-separated integer coordinates.
Input :: struct {
	fuse_boxes: [][3]int, // allocated by parse_input_file, released by free_input
}

// Answer types for task 1 and task 2; both are plain integers.
Result1 :: int
Result2 :: int

// --- Input --- //
// Logs the full parsed input (verbose %#v formatting) at info level.
print_input :: proc(input: Input) {
	INPUT_FORMAT :: "%#v"
	log.infof(INPUT_FORMAT, input)
}

// Reads the puzzle input at `filepath` and parses one fuse box per line.
//
// Each line must hold at least three comma-separated integers ("x,y,z");
// only the first three are used. Whitespace around each field is ignored.
// A field that is not a valid integer is stored as 0 and reported with a
// warning instead of being silently swallowed.
//
// An empty (or whitespace-only) file yields an Input with no fuse boxes.
// The returned slice is allocated with the context allocator and must be
// released with free_input.
//
// Panics if the file cannot be read, if splitting into lines fails, or if a
// line has fewer than three fields (reporting the file and line number).
parse_input_file :: proc(filepath: string) -> Input {
	input: Input

	raw_data, ok := os.read_entire_file_from_filename(filepath)
	if !ok {
		panic("oh no, could not read file")
	}
	defer delete(raw_data)

	content := strings.trim_space(string(raw_data))
	if len(content) == 0 {
		// split_lines("") yields a single empty line, which cannot be parsed.
		return input
	}

	lines, err := strings.split_lines(content)
	if err != .None {
		panic("oh no, failed splitting into lines")
	}
	defer delete(lines)

	input.fuse_boxes = make([][3]int, len(lines))
	for line, i in lines {
		fields := strings.split(line, ",")
		defer delete(fields)

		if len(fields) < 3 {
			log.panicf("%s:%d: expected 3 comma-separated integers, got %q", filepath, i + 1, line)
		}

		for axis in 0 ..< 3 {
			field := strings.trim_space(fields[axis])
			value, parsed := strconv.parse_int(field)
			if !parsed {
				// parse_int may return a partial value on failure; keep the
				// historical fallback of 0 but make the failure visible.
				log.warnf("%s:%d: could not parse %q as an integer, using 0", filepath, i + 1, field)
				value = 0
			}
			input.fuse_boxes[i][axis] = value
		}
	}

	return input
}

// Releases the memory owned by `input` (allocated by parse_input_file).
// The slice is reset to nil afterwards so a stale Input cannot be used to
// read freed memory or be freed twice; deleting a nil slice is a no-op.
free_input :: proc(input: ^Input) {
	delete(input.fuse_boxes)
	input.fuse_boxes = nil
}
// --- Input --- //


// --- Helpers --- //
// --- Helpers --- //


// --- Task 1 --- //
// A candidate connection between two fuse boxes.
Edge :: struct {
	a, b: int, // indices into Input.fuse_boxes; run_task1/run_task2 build them with a < b
	dist: int, // squared Euclidean distance between the two boxes (no square root taken)
}

// Union-find (disjoint set) structure over the element indices 0..<count.
Disjoint_Set :: struct {
	parent: []int, // parent[i] == i marks i as the representative (root) of its set
	size:   []int, // element count of a set; only kept up to date at root indices
}

// Allocates `count` singleton sets: every element is its own root with size 1.
// Memory comes from the context allocator; release it with disjoint_set_destroy.
disjoint_set_init :: proc(ds: ^Disjoint_Set, count: int) {
	ds.parent = make([]int, count)
	ds.size = make([]int, count)

	for _, idx in ds.parent {
		ds.parent[idx] = idx // each element starts as its own representative
		ds.size[idx] = 1
	}
}

// Frees both backing slices and resets the structure to its zero value,
// so neither slice is left dangling.
disjoint_set_destroy :: proc(ds: ^Disjoint_Set) {
	delete(ds.parent)
	delete(ds.size)
	ds^ = {}
}

// Returns the representative (root) of the set containing `x`.
// Applies full path compression: every element on the walk from `x` to the
// root is re-pointed directly at the root.
disjoint_set_find :: proc(ds: ^Disjoint_Set, x: int) -> int {
	// Pass 1: climb parent links until reaching a self-parented root.
	root := x
	for ds.parent[root] != root {
		root = ds.parent[root]
	}

	// Pass 2: revisit the same path and attach each node straight to the root.
	node := x
	for node != root {
		next := ds.parent[node]
		ds.parent[node] = root
		node = next
	}

	return root
}

// Merges the sets containing `a` and `b` (union by size).
// The smaller set's root is attached under the larger one's; on a tie, the
// root of `a` stays the representative. Does nothing if already merged.
disjoint_set_union :: proc(ds: ^Disjoint_Set, a, b: int) {
	root_a := disjoint_set_find(ds, a)
	root_b := disjoint_set_find(ds, b)

	if root_a != root_b {
		big, small := root_a, root_b
		if ds.size[small] > ds.size[big] {
			big, small = small, big
		}
		ds.parent[small] = big
		ds.size[big] += ds.size[small]
	}
}

// Connects the `iteration_count` closest pairs of fuse boxes (ordered by
// squared Euclidean distance). It returns the product of the sizes of the
// three largest resulting circuits.
//
// Each of the first `iteration_count` shortest edges counts as one connection
// attempt, even when both boxes are already in the same circuit.
//
// Edge cases:
// - `iteration_count` is clamped to [0, number of edges]; values <= 0 make
//   no connections at all.
// - With fewer than three circuits, the sizes of all existing circuits are
//   multiplied. With no boxes at all, the result is 0.
run_task1 :: proc(input: Input, iteration_count: int = 1000) -> Result1 {
	result: Result1

	box_count := len(input.fuse_boxes)

	// A complete graph on n nodes has n*(n-1)/2 edges. Reserve them all up
	// front instead of letting the dynamic array grow repeatedly.
	edges := make([dynamic]Edge, 0, box_count * (box_count - 1) / 2)
	defer delete(edges)

	for i in 0 ..< box_count {
		for j in i + 1 ..< box_count {
			delta := input.fuse_boxes[i] - input.fuse_boxes[j]
			delta *= delta
			append(&edges, Edge{a = i, b = j, dist = delta[0] + delta[1] + delta[2]})
		}
	}

	sort.quick_sort_proc(edges[:], proc(a, b: Edge) -> int {
		return a.dist - b.dist
	})

	ds: Disjoint_Set
	disjoint_set_init(&ds, box_count)
	defer disjoint_set_destroy(&ds)

	// Check the limit before any union, so a non-positive count connects nothing.
	connections := clamp(iteration_count, 0, len(edges))
	for e in edges[:connections] {
		disjoint_set_union(&ds, e.a, e.b)
	}

	component_sizes := make([dynamic]int, 0, box_count)
	defer delete(component_sizes)

	for i in 0 ..< box_count {
		if ds.parent[i] == i {
			append(&component_sizes, ds.size[i])
		}
	}

	if len(component_sizes) == 0 {
		return result
	}

	// Ascending sort: the largest circuits sit at the end of the slice.
	sort.quick_sort(component_sizes[:])
	top_count := min(3, len(component_sizes))
	result = 1
	for size in component_sizes[len(component_sizes) - top_count:] {
		result *= size
	}

	return result
}

// Logs the task 1 answer at info level.
print_result1 :: proc(result: Result1) {
	TASK1_FORMAT :: "Task 1: %d"
	log.infof(TASK1_FORMAT, result)
}
// --- Task 1 --- //


// --- Task 2 --- //
// Connects fuse boxes in order of increasing squared distance, skipping pairs
// already in the same circuit (Kruskal's algorithm), until every box belongs
// to a single circuit. It returns the product of the first (x) coordinates of
// the two boxes joined by that final connection.
//
// With fewer than two boxes no connection can be made, so the result is 0.
run_task2 :: proc(input: Input) -> Result2 {
	result: Result2

	box_count := len(input.fuse_boxes)
	if box_count < 2 {
		// Without this guard, last_edge would stay zero-valued and index
		// fuse_boxes[0] (out of bounds when there are no boxes).
		return result
	}

	// A complete graph on n nodes has n*(n-1)/2 edges; reserve them up front.
	edges := make([dynamic]Edge, 0, box_count * (box_count - 1) / 2)
	defer delete(edges)

	for i in 0 ..< box_count {
		for j in i + 1 ..< box_count {
			delta := input.fuse_boxes[i] - input.fuse_boxes[j]
			delta *= delta
			append(&edges, Edge{a = i, b = j, dist = delta[0] + delta[1] + delta[2]})
		}
	}

	sort.quick_sort_proc(edges[:], proc(a, b: Edge) -> int {
		return a.dist - b.dist
	})

	ds: Disjoint_Set
	disjoint_set_init(&ds, box_count)
	defer disjoint_set_destroy(&ds)

	components := box_count
	last_edge: Edge
	for e in edges {
		if disjoint_set_find(&ds, e.a) == disjoint_set_find(&ds, e.b) {
			continue // already in the same circuit; would form a cycle
		}
		disjoint_set_union(&ds, e.a, e.b)
		components -= 1
		last_edge = e
		if components == 1 {
			break
		}
	}

	log.debugf("%v", last_edge)
	log.debugf("%v - %v", input.fuse_boxes[last_edge.a], input.fuse_boxes[last_edge.b])

	result = input.fuse_boxes[last_edge.a][0] * input.fuse_boxes[last_edge.b][0]

	return result
}

// Logs the task 2 answer at info level.
print_result2 :: proc(result: Result2) {
	TASK2_FORMAT :: "Task 2: %d"
	log.infof(TASK2_FORMAT, result)
}
// --- Task 2 --- //

// Entry point for day 8: parses "input/day_08.txt", then solves and logs
// task 1 followed by task 2. The parsed input is freed on exit.
run :: proc() {
	input := parse_input_file("input/day_08.txt")
	defer free_input(&input)

	print_result1(run_task1(input))
	print_result2(run_task2(input))
}