package sanitize

import (
	"bytes"
	"encoding/hex"
	"fmt"
	"strconv"
	"strings"
	"time"
	"unicode/utf8"
)

type (
	// Part is one piece of a parsed Query. A string Part is raw SQL text that
	// is copied through unchanged; an int Part is an argument placeholder
	// number, where 1 refers to the first argument ($1).
	Part any

	// Query is SQL text split into an ordered sequence of Parts.
	Query struct {
		Parts []Part
	}
)

// utf8.DecodeRuneInString reports both invalid input and a literal U+FFFD (the
// Unicode replacement character) as utf8.RuneError. The two are told apart by
// width: a genuine U+FFFD is encoded in 3 bytes, whereas invalid input is
// reported with width 1 (and end of input with width 0).
//
// https://github.com/jackc/pgx/issues/1380
const replacementcharacterwidth = len("\uFFFD")

// Sanitize renders q as a single SQL string, replacing each placeholder Part
// with the SQL literal for the corresponding entry of args (placeholder 1 is
// args[0]).
//
// Supported argument types are nil, int64, float64, bool, []byte, string and
// time.Time. Strings and byte slices are quoted with QuoteString and
// QuoteBytes, which assume standard_conforming_strings is on.
//
// Every substituted literal is padded with a space on each side so it cannot
// fuse with neighbouring SQL text. Without the padding, "select 1-$1" with
// int64(-1) renders as "select 1--1", turning the rest of the line into a
// comment (GHSA-m7wr-2xf7-cm9p), and "E$1" with a string argument would turn
// the quoted value into a backslash-escaped E'...' string.
//
// An error is returned if a placeholder number is below 1 or exceeds
// len(args), if an argument or Part has an unsupported type, or if an argument
// is never referenced.
func (q *Query) Sanitize(args ...any) (string, error) {
	argUse := make([]bool, len(args))
	buf := &bytes.Buffer{}

	for _, part := range q.Parts {
		var str string
		switch part := part.(type) {
		case string:
			str = part
		case int:
			argIdx := part - 1
			// $0 would otherwise index args[-1] and panic.
			if argIdx < 0 {
				return "", fmt.Errorf("first sql argument must be > 0")
			}
			if argIdx >= len(args) {
				return "", fmt.Errorf("insufficient arguments")
			}
			arg := args[argIdx]
			switch arg := arg.(type) {
			case nil:
				str = "null"
			case int64:
				str = strconv.FormatInt(arg, 10)
			case float64:
				str = strconv.FormatFloat(arg, 'f', -1, 64)
			case bool:
				str = strconv.FormatBool(arg)
			case []byte:
				str = QuoteBytes(arg)
			case string:
				str = QuoteString(arg)
			case time.Time:
				str = arg.Truncate(time.Microsecond).Format("'2006-01-02 15:04:05.999999999Z07:00:00'")
			default:
				return "", fmt.Errorf("invalid arg type: %T", arg)
			}
			argUse[argIdx] = true

			// Keep the literal a separate token (see the doc comment above).
			str = " " + str + " "
		default:
			return "", fmt.Errorf("invalid Part type: %T", part)
		}
		buf.WriteString(str)
	}

	for i, used := range argUse {
		if !used {
			return "", fmt.Errorf("unused argument: %d", i)
		}
	}
	return buf.String(), nil
}

func ( string) (*Query, error) {
	 := &sqlLexer{
		src:     ,
		stateFn: rawState,
	}

	for .stateFn != nil {
		.stateFn = .stateFn()
	}

	 := &Query{Parts: .parts}

	return , nil
}

// QuoteString returns str as a single-quoted SQL string literal, doubling any
// embedded single quotes. Backslashes are copied through unchanged, so the
// result is only a safe literal when the server has
// standard_conforming_strings on.
func QuoteString(str string) string {
	return "'" + strings.ReplaceAll(str, "'", "''") + "'"
}

// QuoteBytes returns buf as a PostgreSQL bytea literal in hex format
// ('\x...', lowercase hex digits). The backslash is not doubled, so the literal
// is only read correctly when standard_conforming_strings is on.
func QuoteBytes(buf []byte) string {
	return `'\x` + hex.EncodeToString(buf) + "'"
}

// sqlLexer holds the state of one pass over src. Lexing is driven as a chain
// of stateFn values: each state consumes input and returns the next state, and
// a nil state ends the pass.
type sqlLexer struct {
	src        string  // SQL text being split into parts
	start, pos int     // byte offsets: start of text not yet emitted, next rune to read
	nested     int     // multiline comment nesting level.
	stateFn    stateFn // state to run next; nil once lexing is finished
	parts      []Part  // output: raw SQL strings and int placeholder numbers
}

// stateFn is one state of the lexer: it advances the lexer and returns the
// state to run next, or nil to stop.
type stateFn func(*sqlLexer) stateFn

// rawState scans SQL text that is outside any string literal, quoted
// identifier or comment. It switches to the matching state at the start of a
// string literal (' or E'), a quoted identifier ("), a placeholder ($ followed
// by a digit), a line comment (--) or a block comment (/*). At end of input it
// emits any pending raw text and returns nil to stop the lexer.
//
// Only end of input stops the lexer. utf8.DecodeRuneInString reports invalid
// UTF-8 as utf8.RuneError with width 1; such bytes are kept as ordinary raw
// text instead of silently truncating the rest of the query.
func rawState(l *sqlLexer) stateFn {
	for {
		r, width := utf8.DecodeRuneInString(l.src[l.pos:])
		if width == 0 {
			if l.pos > l.start {
				l.parts = append(l.parts, l.src[l.start:l.pos])
				l.start = l.pos
			}
			return nil
		}
		l.pos += width

		switch r {
		case 'e', 'E':
			// E'...' is an escape string only when the E starts a token. If it ends
			// an identifier or keyword (LIKE'..', ELSE'..', DATE'..'), PostgreSQL
			// lexes the quote as an ordinary string, where backslash is literal.
			nextRune, nextWidth := utf8.DecodeRuneInString(l.src[l.pos:])
			if nextRune == '\'' && !followsIdentifierRune(l.src, l.pos-width) {
				l.pos += nextWidth
				return escapeStringState
			}
		case '\'':
			return singleQuoteState
		case '"':
			return doubleQuoteState
		case '$':
			nextRune, _ := utf8.DecodeRuneInString(l.src[l.pos:])
			if '0' <= nextRune && nextRune <= '9' {
				// Emit the raw text before the '$' (skipping an empty part); the
				// placeholder digits start at l.pos.
				if dollarPos := l.pos - width; dollarPos > l.start {
					l.parts = append(l.parts, l.src[l.start:dollarPos])
				}
				l.start = l.pos
				return placeholderState
			}
		case '-':
			nextRune, nextWidth := utf8.DecodeRuneInString(l.src[l.pos:])
			if nextRune == '-' {
				l.pos += nextWidth
				return oneLineCommentState
			}
		case '/':
			nextRune, nextWidth := utf8.DecodeRuneInString(l.src[l.pos:])
			if nextRune == '*' {
				l.pos += nextWidth
				return multilineCommentState
			}
		}
	}
}

// followsIdentifierRune reports whether the byte at offset i of src is
// immediately preceded by a rune that can continue a PostgreSQL identifier:
// an ASCII letter or digit, '_', '$', or any non-ASCII rune (PostgreSQL treats
// every byte with the high bit set as an identifier character).
func followsIdentifierRune(src string, i int) bool {
	if i <= 0 {
		return false
	}
	prev, _ := utf8.DecodeLastRuneInString(src[:i])
	return prev == '_' || prev == '$' || prev >= utf8.RuneSelf ||
		('a' <= prev && prev <= 'z') ||
		('A' <= prev && prev <= 'Z') ||
		('0' <= prev && prev <= '9')
}

// singleQuoteState scans the body of a standard string literal ('...'); the
// opening quote has already been consumed. A doubled quote ('') is an escaped
// quote, and backslashes have no special meaning. This matches PostgreSQL only
// when standard_conforming_strings is on. It returns rawState after the closing
// quote, or emits the pending text and returns nil at end of input
// (unterminated literal).
//
// Invalid UTF-8 inside the literal does not stop the lexer; only end of input
// (width 0) does.
func singleQuoteState(l *sqlLexer) stateFn {
	for {
		r, width := utf8.DecodeRuneInString(l.src[l.pos:])
		if width == 0 {
			if l.pos > l.start {
				l.parts = append(l.parts, l.src[l.start:l.pos])
				l.start = l.pos
			}
			return nil
		}
		l.pos += width

		if r == '\'' {
			nextRune, nextWidth := utf8.DecodeRuneInString(l.src[l.pos:])
			if nextRune != '\'' {
				return rawState
			}
			l.pos += nextWidth // '' is an escaped quote; stay in the literal
		}
	}
}

// doubleQuoteState scans the body of a quoted identifier ("..."); the opening
// quote has already been consumed. A doubled quote ("") is an escaped quote.
// It returns rawState after the closing quote, or emits the pending text and
// returns nil at end of input (unterminated identifier).
//
// Invalid UTF-8 inside the identifier does not stop the lexer; only end of
// input (width 0) does.
func doubleQuoteState(l *sqlLexer) stateFn {
	for {
		r, width := utf8.DecodeRuneInString(l.src[l.pos:])
		if width == 0 {
			if l.pos > l.start {
				l.parts = append(l.parts, l.src[l.start:l.pos])
				l.start = l.pos
			}
			return nil
		}
		l.pos += width

		if r == '"' {
			nextRune, nextWidth := utf8.DecodeRuneInString(l.src[l.pos:])
			if nextRune != '"' {
				return rawState
			}
			l.pos += nextWidth // "" is an escaped quote; stay in the identifier
		}
	}
}

// placeholderState consumes a placeholder value. The $ must have already has
// already been consumed. The first rune must be a digit.
func placeholderState( *sqlLexer) stateFn {
	 := 0

	for {
		,  := utf8.DecodeRuneInString(.src[.pos:])
		.pos += 

		if '0' <=  &&  <= '9' {
			 *= 10
			 += int( - '0')
		} else {
			.parts = append(.parts, )
			.pos -= 
			.start = .pos
			return rawState
		}
	}
}

// escapeStringState scans the body of an escape string literal (E'...'); the
// E and the opening quote have already been consumed. A backslash escapes the
// rune that follows it, so \' does not end the literal, and a doubled quote
// ('') is also an escaped quote. It returns rawState after the closing quote,
// or emits the pending text and returns nil at end of input.
//
// Invalid UTF-8 inside the literal does not stop the lexer; only end of input
// (width 0) does.
func escapeStringState(l *sqlLexer) stateFn {
	for {
		r, width := utf8.DecodeRuneInString(l.src[l.pos:])
		if width == 0 {
			if l.pos > l.start {
				l.parts = append(l.parts, l.src[l.start:l.pos])
				l.start = l.pos
			}
			return nil
		}
		l.pos += width

		switch r {
		case '\\':
			// Skip the escaped rune so that \' or \\ is not misread.
			_, escapedWidth := utf8.DecodeRuneInString(l.src[l.pos:])
			l.pos += escapedWidth
		case '\'':
			nextRune, nextWidth := utf8.DecodeRuneInString(l.src[l.pos:])
			if nextRune != '\'' {
				return rawState
			}
			l.pos += nextWidth
		}
	}
}

// oneLineCommentState scans a "--" line comment; both dashes have already been
// consumed. The comment ends at the first '\n' or '\r', as in PostgreSQL's
// lexer, where backslashes have no special meaning inside a comment. (Treating
// "\" as an escape would swallow the newline and hide the next line's quotes
// and placeholders from the lexer.) It returns rawState after the line break,
// or emits the pending text and returns nil at end of input.
//
// Invalid UTF-8 inside the comment does not stop the lexer; only end of input
// (width 0) does.
func oneLineCommentState(l *sqlLexer) stateFn {
	for {
		r, width := utf8.DecodeRuneInString(l.src[l.pos:])
		if width == 0 {
			if l.pos > l.start {
				l.parts = append(l.parts, l.src[l.start:l.pos])
				l.start = l.pos
			}
			return nil
		}
		l.pos += width

		if r == '\n' || r == '\r' {
			return rawState
		}
	}
}

// multilineCommentState scans a "/* ... */" block comment; the opening "/*"
// has already been consumed. Block comments nest as in PostgreSQL. Each further
// "/*" increments l.nested, and each "*/" either closes one nested level or,
// when l.nested is 0, closes the outermost comment and returns rawState. At end
// of input it emits the pending text and returns nil (unterminated comment).
//
// Invalid UTF-8 inside the comment does not stop the lexer; only end of input
// (width 0) does.
func multilineCommentState(l *sqlLexer) stateFn {
	for {
		r, width := utf8.DecodeRuneInString(l.src[l.pos:])
		if width == 0 {
			if l.pos > l.start {
				l.parts = append(l.parts, l.src[l.start:l.pos])
				l.start = l.pos
			}
			return nil
		}
		l.pos += width

		switch r {
		case '/':
			nextRune, nextWidth := utf8.DecodeRuneInString(l.src[l.pos:])
			if nextRune == '*' {
				l.pos += nextWidth
				l.nested++
			}
		case '*':
			nextRune, nextWidth := utf8.DecodeRuneInString(l.src[l.pos:])
			if nextRune != '/' {
				continue
			}

			l.pos += nextWidth
			if l.nested == 0 {
				return rawState
			}
			l.nested--
		}
	}
}

// SanitizeSQL replaces placeholder values with args. It quotes and escapes args
// as necessary. This function is only safe when standard_conforming_strings is
// on.
func ( string,  ...any) (string, error) {
	,  := NewQuery()
	if  != nil {
		return "", 
	}
	return .Sanitize(...)
}