diff --git a/src/index.ts b/src/index.ts index 68ca30c..6aba7aa 100644 --- a/src/index.ts +++ b/src/index.ts @@ -18,7 +18,7 @@ export * from './lib/impala/ImpalaSqlParserVisitor'; export { SyntaxContextType } from './parser/common/basic-parser-types'; export type * from './parser/common/basic-parser-types'; -export type { SyntaxError, ParseError, ErrorHandler } from './parser/common/parseErrorListener'; +export type { SyntaxError, ParseError, ErrorListener } from './parser/common/parseErrorListener'; /** * @deprecated legacy, will be removed. diff --git a/src/parser/common/basicParser.ts b/src/parser/common/basicParser.ts index e232dc4..fdb69f6 100644 --- a/src/parser/common/basicParser.ts +++ b/src/parser/common/basicParser.ts @@ -17,7 +17,8 @@ import { WordRange, TextSlice, } from './basic-parser-types'; -import ParseErrorListener, { ParseError, ErrorHandler } from './parseErrorListener'; +import ParseErrorListener, { ParseError, ErrorListener } from './parseErrorListener'; +import { ErrorStrategy } from './errorStrategy'; interface IParser extends Parser { // Customized in our parser @@ -46,7 +47,7 @@ export default abstract class BasicParser< protected _parseErrors: ParseError[] = []; /** members for cache end */ - private _errorHandler: ErrorHandler = (error) => { + private _errorListener: ErrorListener = (error) => { this._parseErrors.push(error); }; @@ -90,7 +91,7 @@ export default abstract class BasicParser< * Create an antlr4 lexer from input. * @param input string */ - public createLexer(input: string, errorListener?: ErrorHandler) { + public createLexer(input: string, errorListener?: ErrorListener) { const charStreams = CharStreams.fromString(input.toUpperCase()); const lexer = this.createLexerFormCharStream(charStreams); if (errorListener) { @@ -104,7 +105,7 @@ export default abstract class BasicParser< * Create an antlr4 parser from input. * @param input string */ - public createParser(input: string, errorListener?: ErrorHandler) { + public createParser(input: string, errorListener?: ErrorListener) { const lexer = this.createLexer(input, errorListener); const tokenStream = new CommonTokenStream(lexer); const parser = this.createParserFromTokenStream(tokenStream); @@ -123,9 +124,10 @@ export default abstract class BasicParser< * @param errorListener listen parse errors and lexer errors. * @returns parseTree */ - public parse(input: string, errorListener?: ErrorHandler) { + public parse(input: string, errorListener?: ErrorListener) { const parser = this.createParser(input, errorListener); parser.buildParseTree = true; + parser.errorHandler = new ErrorStrategy(); return parser.program(); } @@ -141,7 +143,7 @@ export default abstract class BasicParser< this._lexer = this.createLexerFormCharStream(this._charStreams); this._lexer.removeErrorListeners(); - this._lexer.addErrorListener(new ParseErrorListener(this._errorHandler)); + this._lexer.addErrorListener(new ParseErrorListener(this._errorListener)); this._tokenStream = new CommonTokenStream(this._lexer); /** @@ -153,6 +155,7 @@ export default abstract class BasicParser< this._parser = this.createParserFromTokenStream(this._tokenStream); this._parser.buildParseTree = true; + this._parser.errorHandler = new ErrorStrategy(); return this._parser; } @@ -165,7 +168,7 @@ export default abstract class BasicParser< * @param errorListener listen errors * @returns parseTree */ - private parseWithCache(input: string, errorListener?: ErrorHandler) { + private parseWithCache(input: string, errorListener?: ErrorListener) { // Avoid parsing the same input repeatedly. if (this._parsedInput === input && !errorListener) { return this._parseTree; @@ -175,7 +178,7 @@ export default abstract class BasicParser< this._parsedInput = input; parser.removeErrorListeners(); - parser.addErrorListener(new ParseErrorListener(this._errorHandler)); + parser.addErrorListener(new ParseErrorListener(this._errorListener)); this._parseTree = parser.program(); @@ -317,6 +320,7 @@ export default abstract class BasicParser< const parser = this.createParserFromTokenStream(tokenStream); parser.removeErrorListeners(); parser.buildParseTree = true; + parser.errorHandler = new ErrorStrategy(); sqlParserIns = parser; c3Context = parser.program(); diff --git a/src/parser/common/errorStrategy.ts b/src/parser/common/errorStrategy.ts new file mode 100644 index 0000000..14e4fa3 --- /dev/null +++ b/src/parser/common/errorStrategy.ts @@ -0,0 +1,75 @@ +import { DefaultErrorStrategy } from 'antlr4ts/DefaultErrorStrategy'; +import { Parser } from 'antlr4ts/Parser'; +import { InputMismatchException } from 'antlr4ts/InputMismatchException'; +import { IntervalSet } from 'antlr4ts/misc/IntervalSet'; +import { ParserRuleContext } from 'antlr4ts/ParserRuleContext'; +import { RecognitionException } from 'antlr4ts/RecognitionException'; +import { Token } from 'antlr4ts/Token'; + +/** + * Base on DefaultErrorStrategy. + * The difference is that it assigns exception to the context.exception when it encounters error. + */ +export class ErrorStrategy extends DefaultErrorStrategy { + public recover(recognizer: Parser, e: RecognitionException): void { + // Mark the context as an anomaly + for ( + let context: ParserRuleContext | undefined = recognizer.context; + context; + context = context.parent + ) { + context.exception = e; + } + + // Error recovery + if ( + this.lastErrorIndex === recognizer.inputStream.index && + this.lastErrorStates && + this.lastErrorStates.contains(recognizer.state) + ) { + recognizer.consume(); + } + this.lastErrorIndex = recognizer.inputStream.index; + if (!this.lastErrorStates) { + this.lastErrorStates = new IntervalSet(); + } + this.lastErrorStates.add(recognizer.state); + let followSet: IntervalSet = this.getErrorRecoverySet(recognizer); + this.consumeUntil(recognizer, followSet); + } + + public recoverInline(recognizer: Parser): Token { + let e: RecognitionException; + if (this.nextTokensContext === undefined) { + e = new InputMismatchException(recognizer); + } else { + e = new InputMismatchException( + recognizer, + this.nextTokensState, + this.nextTokensContext + ); + } + + // Mark the context as an anomaly + for ( + let context: ParserRuleContext | undefined = recognizer.context; + context; + context = context.parent + ) { + context.exception = e; + } + + // Error recovery + let matchedSymbol = this.singleTokenDeletion(recognizer); + if (matchedSymbol) { + recognizer.consume(); + return matchedSymbol; + } + + if (this.singleTokenInsertion(recognizer)) { + return this.getMissingSymbol(recognizer); + } + + throw e; + } +} diff --git a/src/parser/common/parseErrorListener.ts b/src/parser/common/parseErrorListener.ts index 7d1bb10..d7deff1 100644 --- a/src/parser/common/parseErrorListener.ts +++ b/src/parser/common/parseErrorListener.ts @@ -25,16 +25,16 @@ export interface SyntaxError { } /** - * ErrorHandler will be invoked when it encounters a parsing error. + * ErrorListener will be invoked when it encounters a parsing error. * Includes lexical errors and parsing errors. */ -export type ErrorHandler = (parseError: ParseError, originalError: SyntaxError) => void; +export type ErrorListener = (parseError: ParseError, originalError: SyntaxError) => void; export default class ParseErrorListener implements ANTLRErrorListener { - private _errorHandler; + private _errorListener; - constructor(errorListener: ErrorHandler) { - this._errorHandler = errorListener; + constructor(errorListener: ErrorListener) { + this._errorListener = errorListener; } syntaxError( @@ -49,8 +49,8 @@ export default class ParseErrorListener implements ANTLRErrorListener { if (offendingSymbol && offendingSymbol.text !== null) { endCol = charPositionInLine + offendingSymbol.text.length; } - if (this._errorHandler) { - this._errorHandler( + if (this._errorListener) { + this._errorListener( { startLine: line, endLine: line, diff --git a/test/common/basicParser.test.ts b/test/common/basicParser.test.ts index 3e7382c..1321023 100644 --- a/test/common/basicParser.test.ts +++ b/test/common/basicParser.test.ts @@ -1,5 +1,5 @@ import { CommonTokenStream } from 'antlr4ts'; -import { ErrorHandler, FlinkSQL } from '../../src'; +import { ErrorListener, FlinkSQL } from '../../src'; import { FlinkSqlLexer } from '../../src/lib/flinksql/FlinkSqlLexer'; describe('BasicParser unit tests', () => { @@ -12,13 +12,13 @@ describe('BasicParser unit tests', () => { expect(lexer).not.toBeNull(); }); - test('Create lexer with errorHandler', () => { + test('Create lexer with errorListener', () => { const sql = '袋鼠云数栈UED团队'; const errors: any[] = []; - const errorHandler: ErrorHandler = (err) => { + const errorListener: ErrorListener = (err) => { errors.push(err); }; - const lexer = flinkParser.createLexer(sql, errorHandler); + const lexer = flinkParser.createLexer(sql, errorListener); const tokenStream = new CommonTokenStream(lexer); tokenStream.fill(); expect(errors.length).not.toBe(0); @@ -32,24 +32,24 @@ describe('BasicParser unit tests', () => { expect(parser).not.toBeNull(); }); - test('Create parser with errorHandler (lexer error)', () => { + test('Create parser with errorListener (lexer error)', () => { const sql = '袋鼠云数栈UED团队'; const errors: any[] = []; - const errorHandler: ErrorHandler = (err) => { + const errorListener: ErrorListener = (err) => { errors.push(err); }; - const parser = flinkParser.createParser(sql, errorHandler); + const parser = flinkParser.createParser(sql, errorListener); parser.program(); expect(errors.length).not.toBe(0); }); - test('Create parser with errorHandler (parse error)', () => { + test('Create parser with errorListener (parse error)', () => { const sql = 'SHOW TA'; const errors: any[] = []; - const errorHandler: ErrorHandler = (err) => { + const errorListener: ErrorListener = (err) => { errors.push(err); }; - const parser = flinkParser.createParser(sql, errorHandler); + const parser = flinkParser.createParser(sql, errorListener); parser.program(); expect(errors.length).not.toBe(0); }); @@ -57,10 +57,10 @@ describe('BasicParser unit tests', () => { test('Parse right input', () => { const sql = 'SELECT * FROM tb1'; const errors: any[] = []; - const errorHandler: ErrorHandler = (err) => { + const errorListener: ErrorListener = (err) => { errors.push(err); }; - const parseTree = flinkParser.parse(sql, errorHandler); + const parseTree = flinkParser.parse(sql, errorListener); expect(parseTree).not.toBeUndefined(); expect(parseTree).not.toBeNull(); @@ -70,10 +70,10 @@ describe('BasicParser unit tests', () => { test('Parse wrong input', () => { const sql = '袋鼠云数栈UED团队'; const errors: any[] = []; - const errorHandler: ErrorHandler = (err) => { + const errorListener: ErrorListener = (err) => { errors.push(err); }; - const parseTree = flinkParser.parse(sql, errorHandler); + const parseTree = flinkParser.parse(sql, errorListener); expect(parseTree).not.toBeUndefined(); expect(parseTree).not.toBeNull(); diff --git a/test/parser/flinksql/errorStrategy.test.ts b/test/parser/flinksql/errorStrategy.test.ts new file mode 100644 index 0000000..203a34e --- /dev/null +++ b/test/parser/flinksql/errorStrategy.test.ts @@ -0,0 +1,62 @@ +import FlinkSQL from '../../../src/parser/flinksql'; +import { FlinkSqlSplitListener } from '../../../src/parser/flinksql'; +import { FlinkSqlParserListener } from '../../../src/lib/flinksql/FlinkSqlParserListener'; + +const validSQL1 = `INSERT INTO country_page_view +VALUES ('Chinese', 'mumiao', 18), + ('Amercian', 'georage', 22);`; +const validSQL2 = 'SELECT * FROM tb;'; +const inValidSQL = 'CREATE TABLE VALUES;'; + +describe('FlinkSQL ErrorStrategy test', () => { + const flinkSQL = new FlinkSQL(); + test('begin inValid', () => { + const sql = [inValidSQL, validSQL1, validSQL2].join('\n'); + // parse with empty errorListener + const parseTree = flinkSQL.parse(sql, () => {}); + const splitListener = new FlinkSqlSplitListener(); + flinkSQL.listen(splitListener as FlinkSqlParserListener, parseTree); + + const statementCount = splitListener.statementsContext.length; + splitListener.statementsContext.map((item, index) => { + if (index !== statementCount - 1 && index !== statementCount - 2) { + expect(item.exception).not.toBe(undefined); + } else { + expect(item.exception).toBe(undefined); + } + }); + }); + + test('middle inValid', () => { + const sql = [validSQL1, inValidSQL, validSQL2].join('\n'); + // parse with empty errorListener + const parseTree = flinkSQL.parse(sql, () => {}); + const splitListener = new FlinkSqlSplitListener(); + flinkSQL.listen(splitListener as FlinkSqlParserListener, parseTree); + + const statementCount = splitListener.statementsContext.length; + splitListener.statementsContext.map((item, index) => { + if (index !== statementCount - 1 && index !== 0) { + expect(item.exception).not.toBe(undefined); + } else { + expect(item.exception).toBe(undefined); + } + }); + }); + + test('end inValid', () => { + const sql = [validSQL1, validSQL2, inValidSQL].join('\n'); + // parse with empty errorListener + const parseTree = flinkSQL.parse(sql, () => {}); + const splitListener = new FlinkSqlSplitListener(); + flinkSQL.listen(splitListener as FlinkSqlParserListener, parseTree); + + splitListener.statementsContext.map((item, index) => { + if (index !== 0 && index !== 1) { + expect(item.exception).not.toBe(undefined); + } else { + expect(item.exception).toBe(undefined); + } + }); + }); +}); diff --git a/test/parser/hive/errorStrategy.test.ts b/test/parser/hive/errorStrategy.test.ts new file mode 100644 index 0000000..b56fd23 --- /dev/null +++ b/test/parser/hive/errorStrategy.test.ts @@ -0,0 +1,62 @@ +import HiveSQL from '../../../src/parser/hive'; +import { HiveSqlSplitListener } from '../../../src/parser/hive'; +import { HiveSqlParserListener } from '../../../src/lib/hive/HiveSqlParserListener'; + +const validSQL1 = `INSERT INTO country_page_view +VALUES ('Chinese', 'mumiao', 18), + ('Amercian', 'georage', 22);`; +const validSQL2 = 'SELECT * FROM tb;'; +const inValidSQL = 'CREATE TABLE VALUES;'; + +describe('HiveSQL ErrorStrategy test', () => { + const hiveSQL = new HiveSQL(); + test('begin inValid', () => { + const sql = [inValidSQL, validSQL1, validSQL2].join('\n'); + // parse with empty errorListener + const parseTree = hiveSQL.parse(sql, () => {}); + const splitListener = new HiveSqlSplitListener(); + hiveSQL.listen(splitListener as HiveSqlParserListener, parseTree); + + const statementCount = splitListener.statementsContext.length; + splitListener.statementsContext.map((item, index) => { + if (index !== statementCount - 1 && index !== statementCount - 2) { + expect(item.exception).not.toBe(undefined); + } else { + expect(item.exception).toBe(undefined); + } + }); + }); + + test('middle inValid', () => { + const sql = [validSQL1, inValidSQL, validSQL2].join('\n'); + // parse with empty errorListener + const parseTree = hiveSQL.parse(sql, () => {}); + const splitListener = new HiveSqlSplitListener(); + hiveSQL.listen(splitListener as HiveSqlParserListener, parseTree); + + const statementCount = splitListener.statementsContext.length; + splitListener.statementsContext.map((item, index) => { + if (index !== statementCount - 1 && index !== 0) { + expect(item.exception).not.toBe(undefined); + } else { + expect(item.exception).toBe(undefined); + } + }); + }); + + test('end inValid', () => { + const sql = [validSQL1, validSQL2, inValidSQL].join('\n'); + // parse with empty errorListener + const parseTree = hiveSQL.parse(sql, () => {}); + const splitListener = new HiveSqlSplitListener(); + hiveSQL.listen(splitListener as HiveSqlParserListener, parseTree); + + splitListener.statementsContext.map((item, index) => { + if (index !== 0 && index !== 1) { + expect(item.exception).not.toBe(undefined); + } else { + expect(item.exception).toBe(undefined); + } + }); + }); +}); diff --git a/test/parser/impala/errorStrategy.test.ts b/test/parser/impala/errorStrategy.test.ts new file mode 100644 index 0000000..afacae2 --- /dev/null +++ b/test/parser/impala/errorStrategy.test.ts @@ -0,0 +1,62 @@ +import ImpalaSQL from '../../../src/parser/impala'; +import { ImpalaSqlSplitListener } from '../../../src/parser/impala'; +import { ImpalaSqlParserListener } from '../../../src/lib/impala/ImpalaSqlParserListener'; + +const validSQL1 = `INSERT INTO country_page_view +VALUES ('Chinese', 'mumiao', 18), + ('Amercian', 'georage', 22);`; +const validSQL2 = 'SELECT * FROM tb;'; +const inValidSQL = 'CREATE TABLE VALUES;'; + +describe('ImpalaSQL ErrorStrategy test', () => { + const impalaSQL = new ImpalaSQL(); + test('begin inValid', () => { + const sql = [inValidSQL, validSQL1, validSQL2].join('\n'); + // parse with empty errorListener + const parseTree = impalaSQL.parse(sql, () => {}); + const splitListener = new ImpalaSqlSplitListener(); + impalaSQL.listen(splitListener as ImpalaSqlParserListener, parseTree); + + const statementCount = splitListener.statementsContext.length; + splitListener.statementsContext.map((item, index) => { + if (index !== statementCount - 1 && index !== statementCount - 2) { + expect(item.exception).not.toBe(undefined); + } else { + expect(item.exception).toBe(undefined); + } + }); + }); + + test('middle inValid', () => { + const sql = [validSQL1, inValidSQL, validSQL2].join('\n'); + // parse with empty errorListener + const parseTree = impalaSQL.parse(sql, () => {}); + const splitListener = new ImpalaSqlSplitListener(); + impalaSQL.listen(splitListener as ImpalaSqlParserListener, parseTree); + + const statementCount = splitListener.statementsContext.length; + splitListener.statementsContext.map((item, index) => { + if (index !== statementCount - 1 && index !== 0) { + expect(item.exception).not.toBe(undefined); + } else { + expect(item.exception).toBe(undefined); + } + }); + }); + + test('end inValid', () => { + const sql = [validSQL1, validSQL2, inValidSQL].join('\n'); + // parse with empty errorListener + const parseTree = impalaSQL.parse(sql, () => {}); + const splitListener = new ImpalaSqlSplitListener(); + impalaSQL.listen(splitListener as ImpalaSqlParserListener, parseTree); + + splitListener.statementsContext.map((item, index) => { + if (index !== 0 && index !== 1) { + expect(item.exception).not.toBe(undefined); + } else { + expect(item.exception).toBe(undefined); + } + }); + }); +}); diff --git a/test/parser/mysql/errorStrategy.test.ts b/test/parser/mysql/errorStrategy.test.ts new file mode 100644 index 0000000..4dca176 --- /dev/null +++ b/test/parser/mysql/errorStrategy.test.ts @@ -0,0 +1,62 @@ +import MySQL from '../../../src/parser/mysql'; +import { MysqlSplitListener } from '../../../src/parser/mysql'; +import { MySqlParserListener } from '../../../src/lib/mysql/MySqlParserListener'; + +const validSQL1 = `INSERT INTO country_page_view +VALUES ('Chinese', 'mumiao', 18), + ('Amercian', 'georage', 22);`; +const validSQL2 = 'SELECT * FROM tb;'; +const inValidSQL = 'CREATE TABLE VALUES;'; + +describe('MySQL ErrorStrategy test', () => { + const mysql = new MySQL(); + test('begin inValid', () => { + const sql = [inValidSQL, validSQL1, validSQL2].join('\n'); + // parse with empty errorListener + const parseTree = mysql.parse(sql, () => {}); + const splitListener = new MysqlSplitListener(); + mysql.listen(splitListener as MySqlParserListener, parseTree); + + const statementCount = splitListener.statementsContext.length; + splitListener.statementsContext.map((item, index) => { + if (index !== statementCount - 1 && index !== statementCount - 2) { + expect(item.exception).not.toBe(undefined); + } else { + expect(item.exception).toBe(undefined); + } + }); + }); + + test('middle inValid', () => { + const sql = [validSQL1, inValidSQL, validSQL2].join('\n'); + // parse with empty errorListener + const parseTree = mysql.parse(sql, () => {}); + const splitListener = new MysqlSplitListener(); + mysql.listen(splitListener as MySqlParserListener, parseTree); + + const statementCount = splitListener.statementsContext.length; + splitListener.statementsContext.map((item, index) => { + if (index !== statementCount - 1 && index !== 0) { + expect(item.exception).not.toBe(undefined); + } else { + expect(item.exception).toBe(undefined); + } + }); + }); + + test('end inValid', () => { + const sql = [validSQL1, validSQL2, inValidSQL].join('\n'); + // parse with empty errorListener + const parseTree = mysql.parse(sql, () => {}); + const splitListener = new MysqlSplitListener(); + mysql.listen(splitListener as MySqlParserListener, parseTree); + + splitListener.statementsContext.map((item, index) => { + if (index !== 0 && index !== 1) { + expect(item.exception).not.toBe(undefined); + } else { + expect(item.exception).toBe(undefined); + } + }); + }); +}); diff --git a/test/parser/pgsql/errorStrategy.test.ts b/test/parser/pgsql/errorStrategy.test.ts new file mode 100644 index 0000000..049c3ec --- /dev/null +++ b/test/parser/pgsql/errorStrategy.test.ts @@ -0,0 +1,64 @@ +import PgSQL from '../../../src/parser/pgsql'; +import { PgSqlSplitListener } from '../../../src/parser/pgsql'; +import { PostgreSQLParserListener } from '../../../src/lib/pgsql/PostgreSQLParserListener'; + +const validSQL1 = `INSERT INTO country_page_view +VALUES ('Chinese', 'mumiao', 18), + ('Amercian', 'georage', 22);`; +const validSQL2 = 'SELECT * FROM tb;'; +const inValidSQL = 'CREATE TABLE'; + +describe('PgSQL ErrorStrategy test', () => { + const pgSQL = new PgSQL(); + + // TODO: handle unexpected case + // test('begin inValid', () => { + // const sql = [inValidSQL, validSQL1, validSQL2].join('\n'); + // // parse with empty errorListener + // const parseTree = pgSQL.parse(sql, () => {}); + // const splitListener = new PgSqlSplitListener(); + // pgSQL.listen(splitListener as PostgreSQLParserListener, parseTree); + + // const statementCount = splitListener.statementsContext.length; + // splitListener.statementsContext.map((item, index) => { + // if(index !== statementCount-1 && index !== statementCount - 2) { + // expect(item.exception).not.toBe(undefined); + // } else { + // expect(item.exception).toBe(undefined); + // } + // }) + // }); + + test('middle inValid', () => { + const sql = [validSQL1, inValidSQL, validSQL2].join('\n'); + // parse with empty errorListener + const parseTree = pgSQL.parse(sql, () => {}); + const splitListener = new PgSqlSplitListener(); + pgSQL.listen(splitListener as PostgreSQLParserListener, parseTree); + + const statementCount = splitListener.statementsContext.length; + splitListener.statementsContext.map((item, index) => { + if (index !== statementCount - 1 && index !== 0) { + expect(item.exception).not.toBe(undefined); + } else { + expect(item.exception).toBe(undefined); + } + }); + }); + + test('end inValid', () => { + const sql = [validSQL1, validSQL2, inValidSQL].join('\n'); + // parse with empty errorListener + const parseTree = pgSQL.parse(sql, () => {}); + const splitListener = new PgSqlSplitListener(); + pgSQL.listen(splitListener as PostgreSQLParserListener, parseTree); + + splitListener.statementsContext.map((item, index) => { + if (index !== 0 && index !== 1) { + expect(item.exception).not.toBe(undefined); + } else { + expect(item.exception).toBe(undefined); + } + }); + }); +}); diff --git a/test/parser/spark/errorStrategy.test.ts b/test/parser/spark/errorStrategy.test.ts new file mode 100644 index 0000000..d9b2b88 --- /dev/null +++ b/test/parser/spark/errorStrategy.test.ts @@ -0,0 +1,65 @@ +import SparkSQL from '../../../src/parser/spark'; +import { SparkSqlSplitListener } from '../../../src/parser/spark'; +import { SparkSqlParserListener } from '../../../src/lib/spark/SparkSqlParserListener'; + +const validSQL1 = `INSERT INTO country_page_view +VALUES ('Chinese', 'mumiao', 18), + ('Amercian', 'georage', 22);`; +const validSQL2 = 'SELECT * FROM tb;'; +const inValidSQL = 'CREATE TABLE;'; + +describe('SparkSQL ErrorStrategy test', () => { + const sparkSQL = new SparkSQL(); + + // TODO: handle unexpected case + // test('begin inValid', () => { + // const sql = [inValidSQL, validSQL1, validSQL2].join('\n'); + // // parse with empty errorListener + // const parseTree = sparkSQL.parse(sql, () => {}); + // const splitListener = new SparkSqlSplitListener(); + // sparkSQL.listen(splitListener as SparkSqlParserListener, parseTree); + + // const statementCount = splitListener.statementsContext.length; + // splitListener.statementsContext.map((item, index) => { + // if(index !== statementCount-1 && index !== statementCount - 2) { + // expect(item.exception).not.toBe(undefined); + // } else { + // expect(item.exception).toBe(undefined); + // } + // }) + // }); + + // TODO: handle unexpected case + // test('middle inValid', () => { + // const sql = [validSQL1, inValidSQL, validSQL2].join('\n'); + // // parse with empty errorListener + // const parseTree = sparkSQL.parse(sql, () => {}); + // const splitListener = new SparkSqlSplitListener(); + // sparkSQL.listen(splitListener as SparkSqlParserListener, parseTree); + + // const statementCount = splitListener.statementsContext.length; + // splitListener.statementsContext.map((item, index) => { + // if(index !== statementCount-1 && index !== 0) { + // expect(item.exception).not.toBe(undefined); + // } else { + // expect(item.exception).toBe(undefined); + // } + // }) + // }); + + test('end inValid', () => { + const sql = [validSQL1, validSQL2, inValidSQL].join('\n'); + // parse with empty errorListener + const parseTree = sparkSQL.parse(sql, () => {}); + const splitListener = new SparkSqlSplitListener(); + sparkSQL.listen(splitListener as SparkSqlParserListener, parseTree); + + splitListener.statementsContext.map((item, index) => { + if (index !== 0 && index !== 1) { + expect(item.exception).not.toBe(undefined); + } else { + expect(item.exception).toBe(undefined); + } + }); + }); +}); diff --git a/test/parser/trinosql/errorStrategy.test.ts b/test/parser/trinosql/errorStrategy.test.ts new file mode 100644 index 0000000..9e8a581 --- /dev/null +++ b/test/parser/trinosql/errorStrategy.test.ts @@ -0,0 +1,62 @@ +import TrinoSQL from '../../../src/parser/trinosql'; +import { TrinoSqlSplitListener } from '../../../src/parser/trinosql'; +import { TrinoSqlListener } from '../../../src/lib/trinosql/TrinoSqlListener'; + +const validSQL1 = `INSERT INTO country_page_view +VALUES ('Chinese', 'mumiao', 18), + ('Amercian', 'georage', 22);`; +const validSQL2 = 'SELECT * FROM tb;'; +const inValidSQL = 'CREATE TABLE VALUES;'; + +describe('TrinoSQL ErrorStrategy test', () => { + const trinoSQL = new TrinoSQL(); + test('begin inValid', () => { + const sql = [inValidSQL, validSQL1, validSQL2].join('\n'); + // parse with empty errorListener + const parseTree = trinoSQL.parse(sql, () => {}); + const splitListener = new TrinoSqlSplitListener(); + trinoSQL.listen(splitListener as TrinoSqlListener, parseTree); + + const statementCount = splitListener.statementsContext.length; + splitListener.statementsContext.map((item, index) => { + if (index !== statementCount - 1 && index !== statementCount - 2) { + expect(item.exception).not.toBe(undefined); + } else { + expect(item.exception).toBe(undefined); + } + }); + }); + + test('middle inValid', () => { + const sql = [validSQL1, inValidSQL, validSQL2].join('\n'); + // parse with empty errorListener + const parseTree = trinoSQL.parse(sql, () => {}); + const splitListener = new TrinoSqlSplitListener(); + trinoSQL.listen(splitListener as TrinoSqlListener, parseTree); + + const statementCount = splitListener.statementsContext.length; + splitListener.statementsContext.map((item, index) => { + if (index !== statementCount - 1 && index !== 0) { + expect(item.exception).not.toBe(undefined); + } else { + expect(item.exception).toBe(undefined); + } + }); + }); + + test('end inValid', () => { + const sql = [validSQL1, validSQL2, inValidSQL].join('\n'); + // parse with empty errorListener + const parseTree = trinoSQL.parse(sql, () => {}); + const splitListener = new TrinoSqlSplitListener(); + trinoSQL.listen(splitListener as TrinoSqlListener, parseTree); + + splitListener.statementsContext.map((item, index) => { + if (index !== 0 && index !== 1) { + expect(item.exception).not.toBe(undefined); + } else { + expect(item.exception).toBe(undefined); + } + }); + }); +});