package sanitize
import (
"bytes"
"encoding/hex"
"fmt"
"strconv"
"strings"
"time"
"unicode/utf8"
)
// Part is one element of a parsed Query. Sanitize accepts two dynamic types:
// a string, which is copied to the output verbatim, and an int, which is the
// 1-based number of the argument to substitute. Any other dynamic type makes
// Sanitize return an "invalid Part type" error.
type Part any
// Query is a SQL statement split into literal text and placeholder
// references, as produced by NewQuery and consumed by Sanitize.
type Query struct {
// Parts holds the pieces of the statement in source order.
Parts []Part
}
// replacementcharacterwidth is the number of bytes a literal U+FFFD
// (utf8.RuneError) occupies in UTF-8. The lexer states compare the width
// returned by utf8.DecodeRuneInString against it: a RuneError with this width
// is a genuine U+FFFD in the input, while any other width (0 at end of input,
// 1 for an undecodable byte) makes the state stop lexing.
const replacementcharacterwidth = 3
// Sanitize renders the query as a single SQL string, replacing every
// placeholder part with the SQL literal form of the matching argument.
//
// Placeholder $N refers to args[N-1]. Supported argument types are nil
// (rendered as null), int64, float64, bool, []byte (via QuoteBytes), string
// (via QuoteString) and time.Time (truncated to microseconds and rendered as
// a quoted timestamp with zone offset). Every substituted value is surrounded
// by a single space on each side.
//
// An error is returned when a placeholder number is less than 1 or refers
// past the end of args, when an argument has an unsupported type, when a part
// is neither a string nor an int, or when an argument is never referenced.
func (q *Query) Sanitize(args ...any) (string, error) {
	argUse := make([]bool, len(args))
	buf := &bytes.Buffer{}

	for _, part := range q.Parts {
		var str string
		switch part := part.(type) {
		case string:
			str = part
		case int:
			// Placeholders are 1-based. $0 (or a negative number in a
			// hand-built Query) would index before the start of args and
			// panic, so reject it explicitly.
			if part < 1 {
				return "", fmt.Errorf("invalid placeholder: $%d", part)
			}
			argIdx := part - 1
			if argIdx >= len(args) {
				return "", fmt.Errorf("insufficient arguments")
			}

			switch arg := args[argIdx].(type) {
			case nil:
				str = "null"
			case int64:
				str = strconv.FormatInt(arg, 10)
			case float64:
				str = strconv.FormatFloat(arg, 'f', -1, 64)
			case bool:
				str = strconv.FormatBool(arg)
			case []byte:
				str = QuoteBytes(arg)
			case string:
				str = QuoteString(arg)
			case time.Time:
				str = arg.Truncate(time.Microsecond).Format("'2006-01-02 15:04:05.999999999Z07:00:00'")
			default:
				return "", fmt.Errorf("invalid arg type: %T", arg)
			}
			argUse[argIdx] = true

			// Keep the substituted value from fusing with neighbouring
			// tokens. Without the padding, "1-$1" with the argument -1 would
			// render as "1--1", where "--" starts a line comment and hides
			// the rest of the statement (SQL injection via line comment
			// creation).
			str = " " + str + " "
		default:
			return "", fmt.Errorf("invalid Part type: %T", part)
		}
		buf.WriteString(str)
	}

	for i, used := range argUse {
		if !used {
			return "", fmt.Errorf("unused argument: %d", i)
		}
	}
	return buf.String(), nil
}
// NewQuery lexes sql into a Query whose Parts alternate between literal SQL
// text and placeholder numbers. Lexing starts in rawState and runs each
// returned state function until one returns nil.
//
// It returns an error when the lexer stops before reaching the end of sql.
// The state functions stop early only when they meet a byte that does not
// decode as UTF-8; returning an error avoids handing back a Query that
// silently lacks the remainder of the statement.
func NewQuery(sql string) (*Query, error) {
	l := &sqlLexer{
		src:     sql,
		stateFn: rawState,
	}

	for l.stateFn != nil {
		l.stateFn = l.stateFn(l)
	}

	// A clean run ends with pos at len(src). Anything less means the text
	// after pos was never examined and is missing from l.parts.
	if l.pos < len(l.src) {
		return nil, fmt.Errorf("invalid UTF-8 in sql near byte offset %d", l.pos-1)
	}

	return &Query{Parts: l.parts}, nil
}
// QuoteString returns str as a single-quoted SQL string literal, doubling
// every embedded single quote.
func QuoteString(str string) string {
	const quote = "'"
	escaped := strings.ReplaceAll(str, quote, quote+quote)
	return quote + escaped + quote
}
// QuoteBytes returns buf as a quoted hex literal of the form '\x<hex digits>',
// with the bytes encoded as lowercase hexadecimal.
func QuoteBytes(buf []byte) string {
	var quoted []byte
	quoted = append(quoted, `'\x`...)
	quoted = append(quoted, hex.EncodeToString(buf)...)
	quoted = append(quoted, '\'')
	return string(quoted)
}
// sqlLexer carries the state shared by the lexer state functions while a SQL
// string is split into parts.
type sqlLexer struct {
// src is the SQL text being lexed.
src string
// start is the byte offset in src of the first byte not yet emitted as a part.
start int
// pos is the byte offset in src of the next rune to decode.
pos int
// nested counts block comments opened inside the current block comment;
// multilineCommentState increments it on "/*" and decrements it on "*/".
nested int
// stateFn is the state to run next; NewQuery loops until it becomes nil.
stateFn stateFn
// parts accumulates the output: string segments of src and int placeholder numbers.
parts []Part
}
// stateFn is one state of the lexer. It consumes input through the lexer it
// is given and returns the state to run next, or nil to end lexing.
type stateFn func (*sqlLexer ) stateFn
// rawState lexes SQL text that is outside any string literal, quoted
// identifier or comment. It scans forward until it finds the start of one of
// those constructs or a placeholder, and returns the state that handles it:
//
//   - e' or E'  -> escapeStringState (any e/E directly before a quote counts)
//   - '         -> singleQuoteState
//   - "         -> doubleQuoteState
//   - $<digit>  -> placeholderState, after emitting the pending text
//   - --        -> oneLineCommentState
//   - /*        -> multilineCommentState
//
// At end of input, or on a byte that does not decode as UTF-8, it emits the
// pending text (if any) and returns nil to stop lexing.
func rawState(l *sqlLexer) stateFn {
	for {
		r, width := utf8.DecodeRuneInString(l.src[l.pos:])
		l.pos += width

		switch r {
		case 'e', 'E':
			if next, nextWidth := utf8.DecodeRuneInString(l.src[l.pos:]); next == '\'' {
				l.pos += nextWidth
				return escapeStringState
			}
		case '\'':
			return singleQuoteState
		case '"':
			return doubleQuoteState
		case '$':
			next, _ := utf8.DecodeRuneInString(l.src[l.pos:])
			if '0' <= next && next <= '9' {
				// Emit the text that precedes the '$'. The '$' itself has
				// already been consumed, so measure up to pos-width; this
				// avoids appending an empty part when the placeholder sits at
				// the very start of the pending text.
				if text := l.src[l.start : l.pos-width]; text != "" {
					l.parts = append(l.parts, text)
				}
				l.start = l.pos
				return placeholderState
			}
		case '-':
			if next, nextWidth := utf8.DecodeRuneInString(l.src[l.pos:]); next == '-' {
				l.pos += nextWidth
				return oneLineCommentState
			}
		case '/':
			if next, nextWidth := utf8.DecodeRuneInString(l.src[l.pos:]); next == '*' {
				l.pos += nextWidth
				return multilineCommentState
			}
		case utf8.RuneError:
			// A width of replacementcharacterwidth is a literal U+FFFD in the
			// input and is ordinary text. Any other width means end of input
			// or an undecodable byte.
			if width != replacementcharacterwidth {
				if l.pos-l.start > 0 {
					l.parts = append(l.parts, l.src[l.start:l.pos])
					l.start = l.pos
				}
				return nil
			}
		}
	}
}
// singleQuoteState lexes the inside of a single-quoted string literal. A
// doubled quote ('') is consumed as part of the literal; a lone quote ends it
// and hands control back to rawState. At end of input, or on a byte that does
// not decode as UTF-8, the pending text is emitted and lexing stops.
func singleQuoteState(l *sqlLexer) stateFn {
	for {
		cur, size := utf8.DecodeRuneInString(l.src[l.pos:])
		l.pos += size

		if cur == '\'' {
			following, followingSize := utf8.DecodeRuneInString(l.src[l.pos:])
			if following != '\'' {
				return rawState
			}
			l.pos += followingSize
			continue
		}

		// Anything other than a failed decode is ordinary literal content; a
		// RuneError of full width is a real U+FFFD and is ordinary too.
		if cur != utf8.RuneError || size == replacementcharacterwidth {
			continue
		}

		if l.start < l.pos {
			l.parts = append(l.parts, l.src[l.start:l.pos])
			l.start = l.pos
		}
		return nil
	}
}
// doubleQuoteState lexes the inside of a double-quoted identifier. A doubled
// quote ("") is consumed as part of the identifier; a lone quote ends it and
// hands control back to rawState. At end of input, or on a byte that does not
// decode as UTF-8, the pending text is emitted and lexing stops.
func doubleQuoteState(l *sqlLexer) stateFn {
	for {
		ch, n := utf8.DecodeRuneInString(l.src[l.pos:])
		l.pos += n

		switch {
		case ch == '"':
			peek, peekLen := utf8.DecodeRuneInString(l.src[l.pos:])
			if peek != '"' {
				return rawState
			}
			l.pos += peekLen
		case ch == utf8.RuneError && n != replacementcharacterwidth:
			if pending := l.src[l.start:l.pos]; len(pending) > 0 {
				l.parts = append(l.parts, pending)
				l.start = l.pos
			}
			return nil
		}
	}
}
// placeholderState lexes the decimal digits that follow a '$' and appends
// the resulting number to the lexer's parts as an int. The first non-digit
// rune is left unconsumed, the pending-text start is moved to it, and control
// returns to rawState.
//
// Numbers too large for an int saturate at the maximum int value instead of
// wrapping around, so an absurdly long digit run can never alias a small,
// valid placeholder number.
func placeholderState(l *sqlLexer) stateFn {
	const maxInt = int(^uint(0) >> 1)

	num := 0
	for {
		r, width := utf8.DecodeRuneInString(l.src[l.pos:])
		l.pos += width

		if r < '0' || r > '9' {
			l.parts = append(l.parts, num)
			// Un-read the rune that ended the number so rawState sees it.
			l.pos -= width
			l.start = l.pos
			return rawState
		}

		digit := int(r - '0')
		if num > (maxInt-digit)/10 {
			// num*10 + digit would overflow; clamp and keep consuming digits.
			num = maxInt
			continue
		}
		num = num*10 + digit
	}
}
// escapeStringState lexes the inside of an escape string literal (one opened
// with e' or E'). A backslash consumes the rune after it unconditionally, a
// doubled quote ('') stays inside the literal, and a lone quote ends it and
// hands control back to rawState. At end of input, or on a byte that does not
// decode as UTF-8, the pending text is emitted and lexing stops.
func escapeStringState(l *sqlLexer) stateFn {
	for {
		ch, n := utf8.DecodeRuneInString(l.src[l.pos:])
		l.pos += n

		switch {
		case ch == '\\':
			// Whatever follows the backslash is escaped; skip it unseen.
			_, skipped := utf8.DecodeRuneInString(l.src[l.pos:])
			l.pos += skipped
		case ch == '\'':
			peek, peekLen := utf8.DecodeRuneInString(l.src[l.pos:])
			if peek != '\'' {
				return rawState
			}
			l.pos += peekLen
		case ch == utf8.RuneError && n != replacementcharacterwidth:
			if l.pos > l.start {
				l.parts = append(l.parts, l.src[l.start:l.pos])
				l.start = l.pos
			}
			return nil
		}
	}
}
// oneLineCommentState lexes the body of a "--" comment. The comment runs to
// the first newline or carriage return, after which control returns to
// rawState. At end of input, or on a byte that does not decode as UTF-8, the
// pending text is emitted and lexing stops.
//
// Backslash is deliberately not treated as an escape here: in SQL a "--"
// comment always ends at the end of the line, so a trailing backslash must
// not swallow the line terminator and hide placeholders on the next line.
func oneLineCommentState(l *sqlLexer) stateFn {
	for {
		r, width := utf8.DecodeRuneInString(l.src[l.pos:])
		l.pos += width

		switch r {
		case '\n', '\r':
			return rawState
		case utf8.RuneError:
			// A width of replacementcharacterwidth is a literal U+FFFD and is
			// ordinary comment text.
			if width != replacementcharacterwidth {
				if l.pos-l.start > 0 {
					l.parts = append(l.parts, l.src[l.start:l.pos])
					l.start = l.pos
				}
				return nil
			}
		}
	}
}
// multilineCommentState lexes the body of a "/* ... */" comment, tracking
// nesting in l.nested. Each inner "/*" raises the depth; each "*/" either
// lowers it or, at depth zero, ends the comment and hands control back to
// rawState. At end of input, or on a byte that does not decode as UTF-8, the
// pending text is emitted and lexing stops.
func multilineCommentState(l *sqlLexer) stateFn {
	for {
		ch, n := utf8.DecodeRuneInString(l.src[l.pos:])
		l.pos += n

		switch {
		case ch == '/':
			if peek, peekLen := utf8.DecodeRuneInString(l.src[l.pos:]); peek == '*' {
				l.pos += peekLen
				l.nested++
			}
		case ch == '*':
			peek, peekLen := utf8.DecodeRuneInString(l.src[l.pos:])
			if peek == '/' {
				l.pos += peekLen
				if l.nested == 0 {
					return rawState
				}
				l.nested--
			}
		case ch == utf8.RuneError && n != replacementcharacterwidth:
			if l.pos > l.start {
				l.parts = append(l.parts, l.src[l.start:l.pos])
				l.start = l.pos
			}
			return nil
		}
	}
}
// SanitizeSQL parses sql with NewQuery and renders it with args substituted
// for its placeholders via Query.Sanitize. An error from either step is
// returned with an empty string.
func SanitizeSQL(sql string, args ...any) (string, error) {
	parsed, parseErr := NewQuery(sql)
	if parseErr == nil {
		return parsed.Sanitize(args...)
	}
	return "", parseErr
}
// The pages are generated with Golds v0.6.7. (GOOS=linux GOARCH=amd64)
// Golds is a Go 101 project developed by Tapir Liu.
// PR and bug reports are welcome and can be submitted to the issue list.
// Please follow @Go100and1 (reachable from the left QR code) to get the latest news of Golds.