96 lines
2.3 KiB
Lua
96 lines
2.3 KiB
Lua
--- Format the SQL held inside a Rust raw string literal with `sqlfluff`.
--
-- The raw-string delimiters (`r"…"`, `r#"…"#`, `r##"…"##`, optionally with a
-- `b`/`c` prefix) are stripped, the inner text is handed to `sqlfluff format -`
-- through the plenary job's `writer`, and the very same delimiters are put
-- back around the output.
--
-- @tparam string text full source text of the raw string literal
-- @treturn table|nil replacement lines for the literal: a single line when
--   the formatter printed one line, otherwise the opening delimiter, the
--   formatted lines and the closing delimiter each as their own entry.
--   nil when `text` is not a raw string literal or the job printed nothing.
local run_formatter = function(text)
  -- Capture the exact delimiters so literals with any number of `#` survive
  -- the round trip instead of assuming exactly one.
  local prefix, hashes = string.match(text, '^([bc]?r)(#*)"')
  if prefix == nil then
    return nil
  end
  local opening = prefix .. hashes .. '"'
  local closing = '"' .. hashes

  local inner_text = string.sub(text, #opening + 1, -(#closing + 1))

  -- `sync()` also returns the exit code; only the output lines are used here.
  local job_result = require("plenary.job"):new({
    command = "sqlfluff",
    args = { "format", "-" },
    writer = inner_text,
  }):sync()

  -- No output means there is nothing to substitute; returning nil makes the
  -- caller skip this literal rather than replace the SQL with an empty string.
  if job_result == nil or #job_result == 0 then
    return nil
  end

  -- Single-line output stays on one line between the delimiters.
  if #job_result == 1 then
    return { opening .. job_result[1] .. closing }
  end

  -- Multi-line output: delimiters go on their own lines around the SQL.
  local result = { opening }
  for _, line in ipairs(job_result) do
    result[#result + 1] = line
  end
  result[#result + 1] = closing
  return result
end
|
|
|
|
-- Tree-sitter query (Rust grammar) that captures, as @sql, raw string
-- literals found in the token tree of `sqlx::query`, `sqlx::query_scalar`
-- and `sqlx::query_as` macro invocations. The @path and @name captures exist
-- only to carry the predicates that restrict which macros match.
local embedded_sql
do
  local sqlx_macro_query = [[
(macro_invocation
  (scoped_identifier
    path: (identifier) @path (#eq? @path "sqlx")
    name: (identifier) @name (#any-of? @name "query" "query_scalar" "query_as"))

  (token_tree (raw_string_literal) @sql)
)
]]
  embedded_sql = vim.treesitter.query.parse("rust", sqlx_macro_query)
end
|
|
|
|
--- Parse a buffer with the Rust tree-sitter parser and return its root node.
-- @param bufnr buffer handle forwarded to `vim.treesitter.get_parser`
-- @return the root node of the first tree returned by `parser:parse()`
local function get_root(bufnr)
  local rust_parser = vim.treesitter.get_parser(bufnr, "rust", {})
  -- parse() yields a list of trees; only the first one is inspected.
  local trees = rust_parser:parse()
  return trees[1]:root()
end
|
|
|
|
--- Reformat every SQL raw string captured by `embedded_sql` in a Rust buffer.
--
-- Each `@sql` capture is passed through `run_formatter`. Replacements are
-- collected first and written back afterwards from the last collected capture
-- to the first, so that (assuming captures arrive in document order) applying
-- one edit does not move the position of an edit that is still pending.
--
-- @param bufnr buffer handle; the current buffer is used when nil
local format = function(bufnr)
  bufnr = bufnr or vim.api.nvim_get_current_buf()

  if vim.api.nvim_get_option_value("filetype", { buf = bufnr }) ~= "rust" then
    vim.notify("SQL format can only be used in Rust", vim.log.levels.WARN)
    return
  end

  local root = get_root(bufnr)

  local changes = {}
  for id, node in embedded_sql:iter_captures(root, bufnr, 0, -1) do
    if embedded_sql.captures[id] == "sql" then
      local start_row, start_col, end_row, end_col = node:range()

      -- run the formatter on the node text; nil means "leave this literal alone"
      local formatted = run_formatter(vim.treesitter.get_node_text(node, bufnr))
      if formatted ~= nil then
        -- The first line lands at the literal's own start column, so only the
        -- following lines need padding to line up with it. Blank lines are
        -- left empty to avoid introducing trailing whitespace.
        local indentation = string.rep(" ", start_col)
        for idx = 2, #formatted do
          if formatted[idx] ~= "" then
            formatted[idx] = indentation .. formatted[idx]
          end
        end

        changes[#changes + 1] = {
          start_row = start_row,
          start_col = start_col,
          end_row = end_row,
          end_col = end_col,
          formatted = formatted,
        }
      end
    end
  end

  -- apply the collected changes in reverse order
  for i = #changes, 1, -1 do
    local change = changes[i]
    vim.api.nvim_buf_set_text(
      bufnr,
      change.start_row,
      change.start_col,
      change.end_row,
      change.end_col,
      change.formatted
    )
  end
end
|
|
|
|
-- Register the :SqlFormat Ex command. The handler deliberately calls
-- `format` with no arguments instead of being `format` itself, so the
-- argument Neovim hands to command callbacks is never taken for a buffer
-- number.
local sql_format_command = function()
  format()
end

vim.api.nvim_create_user_command("SqlFormat", sql_format_command, {})
|