package sqlite3

import "core:c"
import "core:fmt"
import "core:mem"
import "core:slice"
import "core:strings"

import sqlite3 "bindings"

// DB is the connection handle type from the low-level bindings.
DB :: sqlite3.Connection
// Rows wraps a prepared statement produced by `query`. Advance it with
// `rows_next`, read the current row with `rows_scan`, and release it with
// `rows_close`, which finalizes the statement and frees the Rows allocation.
Rows :: struct {
	_stmt: sqlite3.Statement,
}

// open opens a database connection for `filename` via the bindings' `open`.
// Whether a missing file is created depends on how the bindings map `open`
// (sqlite3_open creates it) — TODO confirm against the bindings package.
//
// Returns (connection, true) on success and (nil, false) on failure; the error
// code is printed to stderr. The returned connection must be released with
// `close`.
open :: proc(filename: string) -> (DB, bool) {
	db: DB
	c_filename := strings.clone_to_cstring(filename)
	defer delete(c_filename)
	if res := sqlite3.Result(sqlite3.open(c_filename, &db)); res != .Ok {
		fmt.eprintf("sqlite3.open: %s\n", res)
		// SQLite may hand back a connection handle even when opening fails;
		// it still has to be closed or its resources leak.
		if db != nil {
			_ = sqlite3.close(db)
		}
		return nil, false
	}
	return db, true
}

// close closes the connection `db`.
// Reports true when the bindings return .Ok; otherwise the result code is
// written to stderr and false is returned.
close :: proc(db: DB) -> bool {
	status := sqlite3.Result(sqlite3.close(db))
	succeeded := status == .Ok
	if !succeeded {
		fmt.eprintf("sqlite3.close: %s\n", status)
	}
	return succeeded
}

// query prepares `sql` on `db` and binds `args` to parameters 1..len(args) in
// argument order. The statement is not stepped; iterate with `rows_next` /
// `rows_scan` and release it with `rows_close`.
//
// Supported argument types: int (must fit in a C int), f32, f64, string and
// []u8. Strings and byte slices are bound with a nil destructor
// (SQLITE_STATIC), i.e. without copying, so the caller must keep their backing
// memory alive until the returned Rows is closed.
//
// Returns (nil, false) if preparation fails, an argument has an unsupported
// type or out-of-range value, or a bind is rejected. Nothing is leaked on
// those paths. Errors are printed to stderr.
query :: proc(db: DB, sql: string, args: ..any) -> (^Rows, bool) {
	c_sql := strings.clone_to_cstring(sql)
	defer delete(c_sql)

	rows := new(Rows)
	if res := sqlite3.Result(sqlite3.prepare(db, c_sql, -1, &rows._stmt, nil)); res != .Ok {
		fmt.eprintf("sqlite3.query: %s\n", res)
		free(rows)
		return nil, false
	}

	for arg, i in args {
		idx := c.int(i + 1)
		bound := true
		res: sqlite3.Result = .Ok

		switch v in arg {
		case int:
			// bind_int takes a C int; refuse values that would be silently
			// truncated instead of storing a wrong number.
			// TODO: use a 64-bit bind if the bindings expose sqlite3_bind_int64.
			if v < int(min(c.int)) || v > int(max(c.int)) {
				fmt.eprintf("sqlite3.query: int argument %d does not fit in a C int: %d\n", i + 1, v)
				bound = false
			} else {
				res = sqlite3.Result(sqlite3.bind_int(rows._stmt, idx, c.int(v)))
			}
		case f32:
			res = sqlite3.Result(sqlite3.bind_double(rows._stmt, idx, c.double(v)))
		case f64:
			res = sqlite3.Result(sqlite3.bind_double(rows._stmt, idx, c.double(v)))
		case string:
			// An empty Odin string may carry a nil data pointer, and SQLite
			// binds NULL (not '') for a nil text pointer; use a static "" then.
			text := len(v) == 0 ? cstring("") : strings.unsafe_string_to_cstring(v)
			res = sqlite3.Result(sqlite3.bind_text(rows._stmt, idx, text, c.int(len(v)), nil))
		case []u8:
			// NOTE(review): an empty slice with a nil pointer binds NULL, not a
			// zero-length blob.
			res = sqlite3.Result(sqlite3.bind_blob(rows._stmt, idx, raw_data(v), c.int(len(v)), nil))
		case:
			// Binding nothing would leave the parameter NULL and silently store
			// wrong data, so treat an unsupported type as an error.
			fmt.eprintf("sqlite3.query: unsupported datatype %T\n", arg)
			bound = false
		}

		if bound && res != .Ok {
			fmt.eprintf("sqlite3.query: bind argument %d: %s\n", i + 1, res)
			bound = false
		}
		if !bound {
			_ = sqlite3.finalize(rows._stmt)
			free(rows)
			return nil, false
		}
	}

	return rows, true
}

// exec prepares `sql`, binds `args` (see `query`), steps the statement once
// and finalizes it.
//
// Returns true when the step completes (.Done, the normal result of
// INSERT/UPDATE/CREATE/...) or yields a row, and false when preparation,
// binding or execution fails. Errors are printed to stderr.
exec :: proc(db: DB, sql: string, args: ..any) -> bool {
	rows, ok := query(db, sql, ..args)
	if !ok {
		fmt.eprintf("sqlite3.exec: query failed\n")
		return false
	}
	defer rows_close(rows)

	// rows_next reports .Done as "no more rows" (false). For exec, .Done is
	// success, so inspect the step result directly.
	res := sqlite3.Result(sqlite3.step(rows._stmt))
	#partial switch res {
	case .Done, .Row, .Ok:
		return true
	}
	fmt.eprintf("sqlite3.exec: %s\n", res)
	return false
}

// rows_next advances `rows` to its next result row.
// Returns true when a row is available (step result .Row or .Ok). Returns
// false when the statement is exhausted (.Done) and also on any error. Only
// errors are reported on stderr.
rows_next :: proc(rows: ^Rows) -> bool {
	step_res := sqlite3.Result(sqlite3.step(rows._stmt))
	#partial switch step_res {
	case .Ok, .Row:
		return true
	case .Done:
		return false
	}
	fmt.eprintf("sqlite3.next: %s\n", step_res)
	return false
}

// rows_scan copies the current row's columns into the pointers in `columns`,
// matched by position (columns[0] receives result column 0, and so on).
//
// Accepted destination types per SQLite storage class:
//   Integer -> ^int   (read with column_int, so values are limited to C int range;
//                      TODO use a 64-bit getter if the bindings expose one)
//   Float   -> ^f64 or ^f32
//   Text    -> ^string (cloned with context.allocator; caller must delete)
//   Blob    -> ^[]u8   (cloned with context.allocator; caller must delete)
//   Null    -> destination left untouched; scanning continues with the next column
//
// Returns false (with a message on stderr) on a destination type mismatch or
// an allocation failure while cloning text.
rows_scan :: proc(rows: ^Rows, columns: ..any) -> bool {
	for col, i in columns {
		idx := c.int(i)
		switch sqlite3.column_type(rows._stmt, idx) {
		case .Integer:
			col_int, ok := col.(^int)
			if !ok {
				fmt.eprintf("sqlite3.scan: invalid data type: expected int, got %T\n", col)
				return false
			}
			col_int^ = int(sqlite3.column_int(rows._stmt, idx))
		case .Float:
			value := sqlite3.column_double(rows._stmt, idx)
			switch dst in col {
			case ^f64:
				dst^ = f64(value)
			case ^f32:
				dst^ = f32(value)
			case:
				fmt.eprintf("sqlite3.scan: invalid data type: expected f32 or f64, got %T\n", col)
				return false
			}
		case .Text:
			col_str, ok := col.(^string)
			if !ok {
				fmt.eprintf("sqlite3.scan: invalid data type: expected string, got %T\n", col)
				return false
			}
			raw_str := sqlite3.column_text(rows._stmt, idx)
			err: mem.Allocator_Error
			col_str^, err = strings.clone_from_cstring(raw_str)
			if err != nil {
				fmt.eprintf("sqlite3.scan: err %#v\n", err)
				return false
			}
		case .Blob:
			col_bytes, ok := col.(^[]u8)
			if !ok {
				fmt.eprintf("sqlite3.scan: invalid data type: expected byte slice, got %T\n", col)
				return false
			}
			// Call column_blob before column_bytes, as the SQLite docs advise. The
			// returned buffer is owned by SQLite and becomes invalid on the next
			// step/finalize, so copy it out like the Text case does.
			raw_bytes := sqlite3.column_blob(rows._stmt, idx)
			n_bytes := sqlite3.column_bytes(rows._stmt, idx)
			col_bytes^ = slice.clone(slice.bytes_from_ptr(raw_bytes, int(n_bytes)))
		case .Null:
			// Skip only this column; later columns must still be scanned.
			fmt.printf("sqlite3.scan: column datatype is Null, skipping\n")
		}
	}
	return true
}

// rows_close finalizes the statement held by `rows` and frees the Rows
// allocation. Passing nil is a no-op that returns true.
//
// Returns false (with the code on stderr) when finalize reports an error.
// SQLite reports the error of the statement's last failed step here. The
// statement is destroyed and `rows` is freed in every case.
rows_close :: proc(rows: ^Rows) -> bool {
	if rows == nil {
		return true
	}
	// sqlite3_finalize releases the statement even when it returns an error,
	// so the wrapper must be freed unconditionally to avoid a leak.
	defer free(rows)
	if res := sqlite3.Result(sqlite3.finalize(rows._stmt)); res != .Ok {
		fmt.eprintf("sqlite3.close: %s\n", res)
		return false
	}
	return true
}

// TODO: a "select" call that can fill a struct with columns, should support both single row and multiple rows through procedure groups

/*
	if true {
		conn: sqlite3.Connection
		fmt.printf("sqlite3 open: %s\n", sqlite3.open(cstring("test.db"), &conn))
		stmt: sqlite3.Statement
		fmt.printf(
			"sqlite3 prepare: %s\n",
			sqlite3.prepare(
				conn,
				cstring("create table t1(x integer primary key asc,y);"),
				-1,
				&stmt,
				nil,
			),
		)
		fmt.printf("sqlite3 step: %s\n", sqlite3.step(stmt))
		fmt.printf("sqlite3 finalize: %s\n", sqlite3.finalize(stmt))
		fmt.printf(
			"sqlite3 exec: %s\n",
			sqlite3.exec(
				conn,
				cstring(
					"insert into t1 (y) values (1); insert into t1 (y) values (2); select * from t1;",
				),
				exec_proc,
				conn,
				nil,
			),
		)
		errmsg: cstring
		if sqlite3.exec(conn, cstring("insert into t2 (y) values (1)"), nil, nil, &errmsg) != .Ok {
			fmt.printf("sqlite3 exec err: %#v\n", errmsg)
			sqlite3.free(rawptr(errmsg))
		}
		fmt.printf("sqlite3 close: %s\n", sqlite3.close(conn))
		os.exit(1)
	}
	*/