--- Format the SQL held inside a Rust raw string literal with `sqlfluff`.
--
-- The literal's own delimiters (`r"`, `r#"`, `r##"`, ... and the matching
-- closers) are detected from the text, stripped before the SQL is piped to
-- `sqlfluff format -`, and put back around the formatter's output.
--
-- @tparam string text  source text of the literal, delimiters included
-- @treturn table|nil   replacement lines for the literal, or nil when the
--                      text is not a plain raw string or nothing came back
--                      from the formatter
local run_formatter = function(text)
  -- Split into opening delimiter, SQL body and closing delimiter. `.` matches
  -- newlines in Lua patterns, so multi-line literals are handled as well.
  local open, inner_text, close = string.match(text, '^(r#*")(.-)("#*)$')
  if open == nil then
    return nil
  end

  local job_result = require("plenary.job")
    :new({
      command = "sqlfluff",
      args = { "format", "-" },
      writer = inner_text,
    })
    :sync()

  -- An empty result means no formatted text was produced; skip the literal
  -- rather than replacing the query with an empty string.
  if job_result == nil or #job_result == 0 then
    return nil
  end

  -- add the surrounding delimiters back
  if #job_result == 1 then
    return { open .. job_result[1] .. close }
  end

  local result = { open }
  for _, line in ipairs(job_result) do
    result[#result + 1] = line
  end
  result[#result + 1] = close
  return result
end

-- Tree-sitter query capturing, as @sql, the raw string literals found in the
-- token tree of `sqlx::query!`, `sqlx::query_scalar!` and `sqlx::query_as!`
-- macro invocations.
local embedded_sql = vim.treesitter.query.parse(
  "rust",
  [[
(macro_invocation
  (scoped_identifier
    path: (identifier) @path (#eq? @path "sqlx")
    name: (identifier) @name (#any-of? @name "query" "query_scalar" "query_as"))
  (token_tree
    (raw_string_literal) @sql)
)
]]
)

--- Parse the buffer as Rust and return the root node of its first tree.
-- @tparam number bufnr  buffer handle
-- @return root node of the parsed tree
local get_root = function(bufnr)
  local parser = vim.treesitter.get_parser(bufnr, "rust", {})
  local tree = parser:parse()[1]
  return tree:root()
end

--- Reformat every SQL raw string captured by `embedded_sql` in a Rust buffer.
-- @tparam[opt] number bufnr  buffer handle; defaults to the current buffer
local format = function(bufnr)
  bufnr = bufnr or vim.api.nvim_get_current_buf()

  if vim.api.nvim_get_option_value("filetype", { buf = bufnr }) ~= "rust" then
    vim.notify("SQL format can only be used in Rust")
    return
  end

  local root = get_root(bufnr)

  local changes = {}
  for id, node in embedded_sql:iter_captures(root, bufnr, 0, -1) do
    if embedded_sql.captures[id] == "sql" then
      local start_row, start_col, end_row, end_col = node:range()
      local indentation = string.rep(" ", start_col)

      -- run the formatter on the node text
      local formatted = run_formatter(vim.treesitter.get_node_text(node, bufnr))
      if formatted ~= nil then
        -- The first line continues the existing source line; every following
        -- line is aligned with the column where the literal starts. Blank
        -- lines are left empty so no trailing whitespace is introduced.
        for idx, line in ipairs(formatted) do
          if idx > 1 and line ~= "" then
            formatted[idx] = indentation .. line
          end
        end

        changes[#changes + 1] = {
          start_row = start_row,
          start_col = start_col,
          end_row = end_row,
          end_col = end_col,
          formatted = formatted,
        }
      end
    end
  end

  -- Apply the edits last-to-first so that replacing one literal does not
  -- shift the recorded positions of the literals that precede it.
  for i = #changes, 1, -1 do
    local change = changes[i]
    vim.api.nvim_buf_set_text(
      bufnr,
      change.start_row,
      change.start_col,
      change.end_row,
      change.end_col,
      change.formatted
    )
  end
end

-- :SqlFormat formats the embedded SQL of the current buffer.
vim.api.nvim_create_user_command("SqlFormat", function()
  format()
end, {})