package sqlite3

import "core:c"
import "core:mem"
import "core:slice"
import "core:strings"

import "bindings"

DB :: struct {
	_conn:     bindings.Connection,
	allocator: mem.Allocator,
}

Rows :: struct {
	_stmt:     bindings.Statement,
	allocator: mem.Allocator,
}

Error :: union {
	DB_Error,
	Stmt_Error,
}

DB_Error :: enum {
	None = 0,
	Error,
	Busy,
	Cant_Open,
	Disk_Full,
	No_Memory,
	Not_A_Database,
	Permission_Denied,
	Read_Only,
}

Stmt_Error :: enum {
	None = 0,
	Error,
	Constraint_Violation,
	Data_Too_Big,
	Argument_Mismatch,
	Invalid_Datatype,
	Index_Out_Of_Range,
}

Open_Flags :: distinct bit_set[Open_Flag;i32]
Open_Flag :: enum i32 {
	Create     = i32(bindings.OpenFlag.Create),
	Read_Only  = i32(bindings.OpenFlag.Readonly),
	Read_Write = i32(bindings.OpenFlag.Readwrite),
}

@(private)
db_result_to_error :: proc(res: bindings.Result) -> Error {
	err: Error
	#partial switch res {
	case .Ok:
		err = nil
	case .Perm:
		err = DB_Error.Permission_Denied
	case .Busy:
		err = DB_Error.Busy
	case .Cantopen:
		err = DB_Error.Cant_Open
	case .Notadb:
		err = DB_Error.Not_A_Database
	case .Nomem:
		err = DB_Error.No_Memory
	case .Readonly:
		err = DB_Error.Read_Only
	case .Constraint:
		err = Stmt_Error.Constraint_Violation
	case .Toobig:
		err = Stmt_Error.Data_Too_Big
	case .Mismatch:
		err = Stmt_Error.Argument_Mismatch
	case .Range:
		err = Stmt_Error.Index_Out_Of_Range
	case:
		err = DB_Error.Error
	}
	return err
}

// TODO: Refactor
// - [ ] Make sure to separate concerns, query to get data, exec to insert/update/delete data
// - [x] Implement errors
// - [x] Wrap bindings.connection in DB struct instead of type alias
// - [x] Look over allocator, store allocator in DB struct
// - [x] Look over memory ownership
// - [x] Open flags

open :: proc(
	filename: string,
	flags := Open_Flags{.Read_Write, .Create},
	allocater := context.allocator,
) -> (
	^DB,
	Error,
) {
	db, err := new(DB, allocater)
	if err != nil {
		return nil, DB_Error.No_Memory
	}
	db.allocator = allocater

	c_filename := strings.clone_to_cstring(filename, context.temp_allocator)
	if res := bindings.Result(bindings.open_v2(c_filename, &db._conn, transmute(bindings.OpenFlags)flags, nil));
	   res != .Ok {
		free(db, allocater)
		return nil, db_result_to_error(res)
	}

	return db, nil
}

close :: proc(db: ^DB) -> Error {
	if res := bindings.Result(bindings.close(db._conn)); res != .Ok {
		return db_result_to_error(res)
	}

	free(db, db.allocator)

	return nil
}

query :: proc(db: ^DB, sql: string, args: ..any) -> (^Rows, Error) {
	rows, err := new(Rows, db.allocator)
	if err != nil {
		return nil, DB_Error.No_Memory
	}
	rows.allocator = db.allocator

	c_sql := strings.clone_to_cstring(sql, context.temp_allocator)

	if res := bindings.Result(bindings.prepare_v2(db._conn, c_sql, -1, &rows._stmt, nil)); res != .Ok {
		return nil, db_result_to_error(res)
	}

	for arg, i in args {
		switch v in arg {
		case int:
			if res := bindings.Result(bindings.bind_int(rows._stmt, c.int(i + 1), c.int(v))); res != .Ok {
				return nil, db_result_to_error(res)
			}
		case f64:
			if res := bindings.Result(bindings.bind_double(rows._stmt, c.int(i + 1), c.double(v))); res != .Ok {
				return nil, db_result_to_error(res)
			}
		case string:
			if res := bindings.Result(
				bindings.bind_text(
					rows._stmt,
					c.int(i + 1),
					strings.clone_to_cstring(v, context.temp_allocator),
					c.int(len(v)),
					nil,
				),
			); res != .Ok {
				return nil, db_result_to_error(res)
			}
		case []u8:
			if res := bindings.Result(bindings.bind_blob(rows._stmt, c.int(i + 1), raw_data(v), c.int(len(v)), nil));
			   res != .Ok {
				return nil, db_result_to_error(res)
			}
		case:
			return nil, Stmt_Error.Invalid_Datatype
		}
	}

	return rows, nil
}

exec :: proc(db: ^DB, sql: string) -> Error {
	c_sql := strings.clone_to_cstring(sql, context.temp_allocator)
	if res := bindings.Result(bindings.exec(db._conn, c_sql, nil, nil, nil)); res != .Ok {
		return db_result_to_error(res)
	}
	return nil
}

rows_next :: proc(rows: ^Rows) -> (bool, Error) {
	if res := bindings.Result(bindings.step(rows._stmt)); res != .Ok && res != .Row {
		if res != .Done {
			return false, db_result_to_error(res)
		}
		return false, nil
	}
	return true, nil
}

rows_scan :: proc(rows: ^Rows, columns: ..any) -> Error {
	for &col, i in columns {
		datatype := bindings.column_type(rows._stmt, c.int(i))
		switch datatype {
		case .Integer:
			col_int, ok := col.(^int)
			if !ok {
				return Stmt_Error.Invalid_Datatype
			}
			col_int^ = int(bindings.column_int(rows._stmt, c.int(i)))
		case .Float:
			col_float, ok := col.(^f64)
			if !ok {
				return Stmt_Error.Invalid_Datatype
			}
			col_float^ = f64(bindings.column_double(rows._stmt, c.int(i)))
		case .Text:
			col_str, ok := col.(^string)
			if !ok {
				return Stmt_Error.Invalid_Datatype
			}
			raw_str := bindings.column_text(rows._stmt, c.int(i))
			err: mem.Allocator_Error
			col_str^, err = strings.clone_from_cstring(raw_str, context.temp_allocator)
			if err != nil {
				return DB_Error.No_Memory
			}
		case .Blob:
			col_bytes, ok := col.(^[]u8)
			if !ok {
				return Stmt_Error.Invalid_Datatype
			}
			n_bytes := bindings.column_bytes(rows._stmt, c.int(i))
			raw_bytes := bindings.column_blob(rows._stmt, c.int(i))
			col_bytes^ = slice.bytes_from_ptr(raw_bytes, int(n_bytes))
		case .Null:
			continue
		}
	}
	return nil
}

rows_close :: proc(rows: ^Rows) -> Error {
	if rows == nil {
		return DB_Error.Error
	}

	res := bindings.Result(bindings.finalize(rows._stmt))
	free(rows, rows.allocator)
	if res != .Ok {
		return db_result_to_error(res)
	}

	return nil
}

/*
	if true {
		conn: bindings.Connection
		fmt.printf("bindings open: %s\n", bindings.open(cstring("test.db"), &conn))
		stmt: bindings.Statement
		fmt.printf(
			"bindings prepare: %s\n",
			bindings.prepare(
				conn,
				cstring("create table t1(x integer primary key asc,y);"),
				-1,
				&stmt,
				nil,
			),
		)
		fmt.printf("bindings step: %s\n", bindings.step(stmt))
		fmt.printf("bindings finalize: %s\n", bindings.finalize(stmt))
		fmt.printf(
			"bindings exec: %s\n",
			bindings.exec(
				conn,
				cstring(
					"insert into t1 (y) values (1); insert into t1 (y) values (2); select * from t1;",
				),
				exec_proc,
				conn,
				nil,
			),
		)
		errmsg: cstring
		if bindings.exec(conn, cstring("insert into t2 (y) values (1)"), nil, nil, &errmsg) != .Ok {
			fmt.printf("bindings exec err: %#v\n", errmsg)
			bindings.free(rawptr(errmsg))
		}
		fmt.printf("bindings close: %s\n", bindings.close(conn))
		os.exit(1)
	}
	*/