This commit is contained in:
Kijin-Seija
2024-04-07 19:10:38 +08:00
commit d7cbadf1db
22 changed files with 5900 additions and 0 deletions

15
src/caret.ts Normal file
View File

@ -0,0 +1,15 @@
import { type InsertCaretPlaceholderConfig } from './types'
/** Placeholder text used until a caller supplies its own via `config.text`. */
export const DEFAULT_CARET_PLACEHOLDER = '__CARET_PLACEHOLDER__'

/**
 * Placeholder text currently in effect.
 *
 * `insertCaret` overwrites this binding whenever it receives a non-empty
 * `config.text`, and the new value stays in effect for later calls that do
 * not pass `text`.
 */
export let caretPlaceholder = DEFAULT_CARET_PLACEHOLDER

/**
 * Inserts the caret placeholder into `sql` at the given 1-based line/column.
 *
 * @param sql - SQL text; lines are separated by `\n`.
 * @param config - Caret position and optional placeholder text. When omitted,
 *   `sql` is returned untouched.
 * @returns The SQL with the placeholder inserted and surrounding whitespace
 *   trimmed. If `config.lineNumber` does not address an existing line, `sql`
 *   is returned unchanged instead of throwing.
 */
export const insertCaret = (sql: string, config?: InsertCaretPlaceholderConfig) => {
  if (!config) return sql
  if (config.text) caretPlaceholder = config.text

  const lines = sql.split('\n')
  const lineIndex = config.lineNumber - 1
  const line = lines[lineIndex]
  // Zero, negative, fractional or too-large line numbers all index to
  // `undefined`; there is nowhere to put the caret, so leave the SQL alone.
  if (line === undefined) return sql

  // Clamp the column into the line. Without the lower bound a column below 1
  // becomes a negative `slice` index, which counts from the end of the line.
  const columnIndex = Math.min(Math.max(config.columnNumber - 1, 0), line.length)
  lines[lineIndex] = line.slice(0, columnIndex) + caretPlaceholder + line.slice(columnIndex)
  return lines.join('\n').trim()
}

28
src/index.ts Normal file
View File

@ -0,0 +1,28 @@
import { defaultEntities, defaultRules, defaultStmts } from './parse/default-rules'
import { parse } from './parse'
import { preprocess } from './preprocess'
import { type PluginSettings, type InsertCaretPlaceholderConfig } from './types'
import { defaultPreprocessorList } from './preprocess/default-preprocessor'
import { insertCaret } from './caret'
import { PostgresSQL } from 'dt-sql-parser'
/**
 * Entry point of the plugin: inserts the caret placeholder, runs the
 * preprocessors and parses the resulting SQL into statements and entities.
 */
export class DtSqlParserSemAnalysePlugin {
  private readonly settings: PluginSettings

  /**
   * @param settings - Optional overrides; every missing piece falls back to
   *   the defaults when `parse` runs.
   */
  constructor (settings?: PluginSettings) {
    this.settings = settings ?? {}
  }

  /**
   * Analyses `sql`.
   *
   * @param sql - SQL text to analyse.
   * @param caret - Optional caret position; when given, a placeholder is
   *   inserted there before preprocessing.
   * @returns The result produced by the parse step.
   */
  public parse (sql: string, caret?: InsertCaretPlaceholderConfig) {
    const { preprocessor, parse: parseSettings } = this.settings
    const sqlWithCaret = insertCaret(sql, caret)
    const preparedSql = preprocess(sqlWithCaret, preprocessor ?? defaultPreprocessorList)
    return parse(
      preparedSql,
      parseSettings?.parser ?? new PostgresSQL(),
      parseSettings?.stmts ?? defaultStmts,
      parseSettings?.entities ?? defaultEntities,
      parseSettings?.rules ?? defaultRules
    )
  }
}

View File

@ -0,0 +1,78 @@
import { PostgreSQLParser } from 'dt-sql-parser/dist/lib/pgsql/PostgreSQLParser'
/**
 * Grammar rule names registered as statements by default.
 * Only the select statement is tracked.
 */
export const defaultStmts: string[] = ['selectstmt']
/**
 * Grammar rule names registered as entities by default.
 * The order is the order in which they are registered with the visitor.
 */
export const defaultEntities: string[] = [
  'column_name', // column named directly, e.g. `select column1`
  'columnref', // column named indirectly, e.g. `select schema.column1`
  'table_name',
  'view_name',
  'function_name',
  'schema_name',
  'colid',
  'attr_name',
  'collabel',
  'func_arg_expr'
]
/**
 * Default named rule chains.
 *
 * Each value lists parser rule indices from the outermost node to the entity
 * node. Key order is kept as-is because it is the order in which the rules
 * are registered.
 */
export const defaultRules: Record<string, number[]> = (() => {
  const P = PostgreSQLParser
  // Shared prefixes: an entity inside the select target list / the from clause.
  const inTargetList = [P.RULE_selectstmt, P.RULE_select_clause, P.RULE_target_list]
  const inFromClause = [P.RULE_selectstmt, P.RULE_select_clause, P.RULE_from_clause]

  return {
    select_target_column_simple: [...inTargetList, P.RULE_column_name],
    select_target_column_ref: [...inTargetList, P.RULE_columnref],
    select_target_function: [...inTargetList, P.RULE_function_name],
    select_column_alias: [...inTargetList, P.RULE_collabel],
    select_from_table: [...inFromClause, P.RULE_table_name],
    select_from_view: [...inFromClause, P.RULE_view_name],
    select_from_function: [...inFromClause, P.RULE_function_name],
    column_ref_colid: [P.RULE_columnref, P.RULE_colid],
    column_ref_attr: [P.RULE_columnref, P.RULE_attr_name],
    function_arg_expr: [P.RULE_function_name, P.RULE_func_arg_expr]
  }
})()

21
src/parse/index.ts Normal file
View File

@ -0,0 +1,21 @@
import { PostgresSQL } from 'dt-sql-parser'
import { SQLVisitor } from './visitor'
import { type SQLParseResult } from '../types'
import type BasicParser from 'dt-sql-parser/dist/parser/common/basicParser'
/**
 * Parses `sql` and walks the resulting tree with a freshly configured
 * `SQLVisitor`.
 *
 * @param sql - SQL text to parse.
 * @param parser - Parser used to build the tree; defaults to a new `PostgresSQL`.
 * @param stmts - Grammar rule names to register as statements.
 * @param entities - Grammar rule names to register as entities.
 * @param rules - Named rule chains (parser rule indices) to register.
 * @returns The statements and caret entities collected by the visitor.
 */
export function parse (
  sql: string,
  parser: BasicParser = new PostgresSQL(),
  stmts: string[] = [],
  entities: string[] = [],
  rules: Record<string, number[]> = {}
): SQLParseResult {
  const tree = parser.parse(sql)
  const visitor = new SQLVisitor()
  // Statements and entities go first: addRules rejects a chain whose first or
  // last node has not been registered yet.
  for (const stmt of stmts) visitor.addStmt(stmt)
  for (const entity of entities) visitor.addEntity(entity)
  for (const name of Object.keys(rules)) visitor.addRules(name, rules[name])
  visitor.visit(tree)
  return visitor.getResult()
}

128
src/parse/visitor.ts Normal file
View File

@ -0,0 +1,128 @@
import { type ParserRuleContext } from 'antlr4ts'
import { AbstractParseTreeVisitor, type PostgreSQLParserVisitor } from 'dt-sql-parser'
import { type ProgramContext, PostgreSQLParser } from 'dt-sql-parser/dist/lib/pgsql/PostgreSQLParser'
import { type Entity, type SQLParseResult, type Stmt } from '../types'
import { caretPlaceholder } from '../caret'
/** Reports whether the text covered by `ctx` contains the current caret placeholder. */
function withCaret (ctx: ParserRuleContext) {
  const nodeText = ctx.text
  return nodeText.indexOf(caretPlaceholder) !== -1
}
/**
 * Parse-tree visitor that records registered statements and entities while
 * walking a PostgreSQL parse tree.
 *
 * Nothing is collected until rule names are registered: `addStmt` and
 * `addEntity` each install a `visit<Name>` method on this instance, and
 * `addRules` names the ancestor chains under which an entity is recorded.
 */
export class SQLVisitor extends AbstractParseTreeVisitor<void> implements PostgreSQLParserVisitor<void> {
  private result: SQLParseResult = {
    stmtList: [],
    nerestCaretEntityList: []
  }

  // NOTE(review): the visitor is typed `<void>` yet this returns an object;
  // results are collected in `result`, so the value looks unused. Confirm
  // before changing.
  protected defaultResult = () => ({ list: [], nerestCaret: null })

  /** Discards everything collected so far. */
  public clear () {
    this.result = { stmtList: [], nerestCaretEntityList: [] }
  }

  /** Returns the statements and caret entities collected so far. */
  public getResult () {
    return this.result
  }

  /** Registered statements currently being visited, innermost last. */
  private readonly stmtStack: Stmt[] = []
  /** Matched entities currently being visited, innermost last. */
  private readonly entityStack: Entity[] = []
  /** Rule name -> chain of parser rule indices. */
  private readonly rules = new Map<string, number[]>()
  /** Statement rule index -> names of rules whose chain starts there. */
  private readonly stmtRules = new Map<number, string[]>()
  /** Entity rule index -> names of rules whose chain starts or ends there. */
  private readonly entityRules = new Map<number, string[]>()

  /**
   * Registers a named rule chain.
   *
   * The first node must already be a registered statement or entity and the
   * last node a registered entity; otherwise an error is logged and the rule
   * is ignored.
   *
   * @param name - Rule name; becomes `Entity.rule` and the key under `relatedEntities`.
   * @param rules - Parser rule indices, outermost first.
   */
  public addRules (name: string, rules: number[]) {
    const first = rules[0]
    const last = rules[rules.length - 1]
    if (!this.stmtRules.has(first) && !this.entityRules.has(first)) {
      console.error(`待添加的规则${name}的起始节点未被注册为Statement/Entity请先注册或者调整规则起始节点为已注册的Statement/Entity`)
      return
    }
    if (!this.entityRules.has(last)) {
      console.error(`待添加的规则${name}的结束节点未被注册为Entity请先注册或者调整规则结束节点为已注册的Entity`)
      return
    }
    this.rules.set(name, [...rules])
    if (this.stmtRules.has(first)) this.stmtRules.get(first)?.push(name)
    else if (this.entityRules.has(first)) this.entityRules.get(first)?.push(name)
    this.entityRules.get(last)?.push(name)
  }

  /**
   * Registers the grammar rule `name` as an entity and installs its
   * `visit<Name>` method on this instance.
   *
   * @param name - Grammar rule name, e.g. `column_name`.
   */
  public addEntity (name: string) {
    const ruleIndex: number = (PostgreSQLParser as any)[`RULE_${name}`]
    // Shared with `entityRules`, so rules added later are seen by the visitor below.
    const ruleNames: string[] = []
    this.entityRules.set(ruleIndex, ruleNames)
    const visitorName = `visit${name.slice(0, 1).toUpperCase()}${name.slice(1)}`
    ;(this as any)[visitorName] = (ctx: ParserRuleContext) => {
      // The match result is local to this visit. A flag shared between visits
      // would pop entities that a non-matching visit never pushed.
      const entity = this.recordEntity(ctx, ruleNames)
      if (entity) this.entityStack.push(entity)
      this.visitChildren(ctx)
      if (entity) this.entityStack.pop()
    }
  }

  /**
   * Registers the grammar rule `name` as a statement and installs its
   * `visit<Name>` method on this instance. A statement is appended to
   * `stmtList` after its children have been visited.
   *
   * @param name - Grammar rule name, e.g. `selectstmt`.
   */
  public addStmt (name: string) {
    this.stmtRules.set((PostgreSQLParser as any)[`RULE_${name}`], [])
    const visitorName = `visit${name.slice(0, 1).toUpperCase()}${name.slice(1)}`
    ;(this as any)[visitorName] = (ctx: ParserRuleContext) => {
      this.stmtStack.push({
        text: ctx.text,
        type: ctx.ruleIndex,
        caret: withCaret(ctx),
        relatedEntities: {}
      })
      this.visitChildren(ctx)
      const lastStmt = this.stmtStack.pop()
      if (lastStmt) this.result.stmtList.push(lastStmt)
    }
  }

  /**
   * Builds and stores the entity for `ctx` using the first rule in
   * `ruleNames` whose chain matches the ancestors of `ctx`.
   *
   * The entity is attached to the innermost matched entity if there is one,
   * otherwise to the innermost statement, and also added to the caret list
   * when its text contains the placeholder.
   *
   * @returns The new entity, or `undefined` when no rule matches or `ctx` is
   *   not inside a registered statement.
   */
  private recordEntity (ctx: ParserRuleContext, ruleNames: string[]): Entity | undefined {
    const stmt = this.stmtStack[this.stmtStack.length - 1]
    // An entity always belongs to a statement. Outside of one (for example a
    // rule chain that starts at an entity inside an unregistered statement
    // type) there is nothing to attach to, so the node is ignored.
    if (!stmt) return undefined

    const chain = this.getNodeChain(ctx)
    const rule = ruleNames.find(candidate => this.matchRules(chain, this.rules.get(candidate)))
    if (rule === undefined) return undefined

    const caret = withCaret(ctx)
    const entity: Entity = {
      rule,
      text: ctx.text,
      type: ctx.ruleIndex,
      caret,
      belongsToStmt: stmt,
      relatedEntities: {}
    }
    const owner = this.entityStack[this.entityStack.length - 1] ?? stmt
    if (!owner.relatedEntities[rule]) owner.relatedEntities[rule] = []
    owner.relatedEntities[rule].push(entity)
    if (caret) this.result.nerestCaretEntityList.push(entity)
    return entity
  }

  /** Returns the rule indices from the root of the tree down to `ctx`, inclusive. */
  private getNodeChain (ctx: ParserRuleContext) {
    const chain: number[] = []
    let node: ParserRuleContext | undefined = ctx
    while (node) {
      chain.push(node.ruleIndex)
      node = node.parent
    }
    return chain.reverse()
  }

  /**
   * Returns true when every element of `ruleChain` occurs in `chain` in the
   * same order (other nodes may sit in between); false otherwise, including
   * when `ruleChain` is undefined.
   */
  private matchRules (chain: number[], ruleChain: number[] | undefined) {
    if (!ruleChain) return false
    let searchFrom = 0
    for (const ruleIndex of ruleChain) {
      // Search from the previous match onwards. Looking only at the first
      // occurrence in the whole chain rejects valid chains when a rule index
      // repeats among the ancestors (nested constructs).
      const found = chain.indexOf(ruleIndex, searchFrom)
      if (found === -1) return false
      searchFrom = found
    }
    return true
  }

  /** Visits the children of the root `program` node. */
  visitProgram (ctx: ProgramContext) {
    this.visitChildren(ctx)
  }
}

View File

@ -0,0 +1,24 @@
import { type Preprocessor } from '../types'
/**
 * Completes a truncated `ALTER FUNCTION` statement so that it can be parsed.
 *
 * When `sql` contains `alter function` (any letter case) and has fewer than
 * four whitespace-separated words, ` RESET ALL` is appended.
 *
 * @param sql - SQL text to inspect.
 * @returns `sql` with the suffix appended, or `sql` unchanged.
 */
export const addSuffixForParseAlterFunctionProcessor: Preprocessor = (sql: string) => {
  const suffix = 'RESET ALL'
  const maxWords = 4
  // Words may be separated by newlines or tabs as well as spaces.
  const isAlterFunction = /alter\s+function/.test(sql.toLowerCase())
  const wordCount = sql.split(/\s+/).filter(item => item).length
  if (isAlterFunction && wordCount < maxWords) {
    // The space keeps the suffix from fusing with the last word of `sql`.
    return sql + ' ' + suffix
  }
  return sql
}
/**
 * Completes a truncated `ALTER TABLE` statement so that it can be parsed.
 *
 * When `sql` contains `alter table` (any letter case) and has fewer than
 * four whitespace-separated words, ` SET WITHOUT OIDS` is appended.
 *
 * @param sql - SQL text to inspect.
 * @returns `sql` with the suffix appended, or `sql` unchanged.
 */
export const addSuffixForParseAlterTableProcessor: Preprocessor = (sql: string) => {
  const suffix = 'SET WITHOUT OIDS'
  const maxWords = 4
  // Words may be separated by newlines or tabs as well as spaces.
  const isAlterTable = /alter\s+table/.test(sql.toLowerCase())
  const wordCount = sql.split(/\s+/).filter(item => item).length
  if (isAlterTable && wordCount < maxWords) {
    // The space keeps the suffix from fusing with the last word of `sql`.
    return sql + ' ' + suffix
  }
  return sql
}
/**
 * Preprocessors applied, in this order, when the plugin settings do not
 * provide their own list.
 */
export const defaultPreprocessorList = [addSuffixForParseAlterFunctionProcessor, addSuffixForParseAlterTableProcessor]

8
src/preprocess/index.ts Normal file
View File

@ -0,0 +1,8 @@
import { type Preprocessor } from '../types'
/**
 * Runs `sql` through each preprocessor in order, feeding the output of one
 * into the next.
 *
 * @param sql - SQL text to transform.
 * @param preprocessorList - Transformations to apply; defaults to none.
 * @returns The output of the last preprocessor, or `sql` when the list is empty.
 */
export function preprocess (
  sql: string,
  preprocessorList: Preprocessor[] = []
) {
  let current = sql
  for (const step of preprocessorList) {
    current = step(current)
  }
  return current
}

62
src/types.ts Normal file
View File

@ -0,0 +1,62 @@
import type BasicParser from 'dt-sql-parser/dist/parser/common/basicParser'
/**
 * Where to insert the caret placeholder into a SQL text, and optionally
 * which placeholder text to use.
 */
export interface InsertCaretPlaceholderConfig {
/** Line of the caret; `insertCaret` reads it as `lineNumber - 1`, i.e. 1-based. */
lineNumber: number
/** Column of the caret; `insertCaret` reads it as `columnNumber - 1`, i.e. 1-based. */
columnNumber: number
/** Placeholder text to use instead of the one currently in effect. */
text?: string
}
/** A function that takes SQL text and returns the (possibly rewritten) SQL text. */
export type Preprocessor = (sql: string) => string
/** Optional overrides accepted by the plugin constructor. */
export interface PluginSettings {
/**
* Custom preprocessors; used instead of the default list when provided.
*/
preprocessor?: Preprocessor[]
/**
* Custom parse logic; each field that is left out falls back to its default.
*/
parse?: {
parser?: BasicParser
stmts?: string[]
entities?: string[]
rules?: Record<string, number[]>
}
}
/**
 * NOTE(review): no code visible in this change references this interface;
 * confirm whether it is still needed before relying on it.
 */
export interface SQLParseResultItem {
stmt?: Stmt
stmtType?: number
entity?: Entity
entityType?: number
rule?: string
caret: boolean
}
/** A registered statement found in the parse tree. */
export interface Stmt {
/** Text of the statement node. */
text: string
/** Parser rule index of the statement node. */
type: number
/** Whether `text` contains the caret placeholder. */
caret: boolean
/** Entities attached directly to this statement, grouped by the name of the rule that matched. */
relatedEntities: Record<string, Entity[]>
}
/** A registered entity found in the parse tree under a matching rule chain. */
export interface Entity {
/** Name of the rule that matched this entity. */
rule: string
/** Text of the entity node. */
text: string
/** Parser rule index of the entity node. */
type: number
/** Whether `text` contains the caret placeholder. */
caret: boolean
/** Innermost registered statement that was being visited when the entity was found. */
belongsToStmt: Stmt
/** Entities found inside this one, grouped by the name of the rule that matched. */
relatedEntities: Record<string, Entity[]>
}
/** Everything the visitor collected from one parse. */
export interface SQLParseResult {
/** Registered statements, each appended after its children were visited. */
stmtList: Stmt[]
/** Matched entities whose text contains the caret placeholder. */
nerestCaretEntityList: Entity[]
}