// Copyright (c) 2015-2023 Jeevanandam M (jeeva@myjeeva.com)
// 2023 Segev Dagan (https://github.com/segevda)
// All rights reserved.
// resty source code and usage is governed by a MIT style
// license that can be found in the LICENSE file.

package resty

import (
	
	
	
	
	
	
	
	
	
	
)

// Sentinel errors reported while answering an HTTP digest challenge.

// ErrDigestBadChallenge reports a missing or malformed digest challenge.
var ErrDigestBadChallenge = errors.New("digest: challenge is bad")

// ErrDigestCharset reports a challenge whose charset is not UTF-8.
var ErrDigestCharset = errors.New("digest: unsupported charset")

// ErrDigestAlgNotSupported reports a challenge algorithm that has no entry
// in hashFuncs.
var ErrDigestAlgNotSupported = errors.New("digest: algorithm is not supported")

// ErrDigestQopNotSupported reports a qop list that does not contain "auth".
var ErrDigestQopNotSupported = errors.New("digest: no supported qop in list")

// ErrDigestNoQop reports a challenge that carries no qop at all.
var ErrDigestNoQop = errors.New("digest: qop must be specified")

// hashFuncs maps a digest "algorithm" token to the constructor of the hash
// used for every digest computation. The empty token (no algorithm given)
// and the "-sess" variants are listed explicitly so that a plain map lookup
// is enough to decide whether an algorithm is supported.
//
// RFC 7616 defines the "SHA-512-256" token as SHA-512/256, i.e. the truncated
// variant with its own initial values and a 32-byte digest. That is
// sha512.New512_256; plain sha512.New would produce a 64-byte SHA-512 digest
// that a conforming server cannot verify.
var hashFuncs = map[string]func() hash.Hash{
	"":                 md5.New,
	"MD5":              md5.New,
	"MD5-sess":         md5.New,
	"SHA-256":          sha256.New,
	"SHA-256-sess":     sha256.New,
	"SHA-512-256":      sha512.New512_256,
	"SHA-512-256-sess": sha512.New512_256,
}

// digestCredentials is the username/password pair used to answer a digest
// challenge.
type digestCredentials struct {
	// username is copied into the credentials built for a challenge.
	username string
	// password is copied into the credentials built for a challenge.
	password string
}

// digestTransport pairs a set of digest credentials with an underlying
// http.RoundTripper through which the actual requests are sent.
type digestTransport struct {
	// Embedded so username and password are promoted onto the transport.
	digestCredentials
	// transport performs both the initial and the authenticated round trip.
	transport http.RoundTripper
}

// NOTE(review): the method name, the receiver name and every local identifier
// are missing from this copy of the file, so the block does not compile as
// shown; restore them before building. Given the signature and the wrapped
// http.RoundTripper this is presumably the RoundTrip method — confirm.
//
// As far as the visible calls show, the flow is:
//  1. copy the request into a fresh http.Request with its own Header map,
//  2. send it through the wrapped transport,
//  3. if that fails, or the status is anything other than 401, return the
//     outcome unchanged,
//  4. otherwise read the hdrWwwAuthenticateKey header, parse it with
//     parseChallenge, and build credentials and an authorization value,
//  5. close a response body, set hdrAuthorizationKey and send a second
//     request through the wrapped transport.
//
// ErrDigestBadChallenge is returned when the challenge header is empty;
// errors from GetBody, parseChallenge, authorize and Body.Close are returned
// as they are.
func ( *digestTransport) ( *http.Request) (*http.Response, error) {
	// Copy the request, so we don't modify the input.
	 := new(http.Request)
	* = *
	// A new Header map is allocated so that setting the authorization header
	// later does not write into the map the struct copy shares.
	.Header = make(http.Header)
	for ,  := range .Header {
		.Header[] = 
	}

	// Fix http: ContentLength=xxx with Body length 0
	if .Body == nil {
		.ContentLength = 0
	} else if .GetBody != nil {
		// GetBody yields a body reader; a failure aborts before any request
		// is sent.
		var  error
		.Body,  = .GetBody()
		if  != nil {
			return nil, 
		}
	}

	// Make a request to get the 401 that contains the challenge.
	,  := .transport.RoundTrip()
	// Anything other than a clean 401 is handed straight back to the caller.
	if  != nil || .StatusCode != http.StatusUnauthorized {
		return , 
	}
	 := .Header.Get(hdrWwwAuthenticateKey)
	if  == "" {
		return , ErrDigestBadChallenge
	}

	,  := parseChallenge()
	if  != nil {
		return , 
	}

	// Form credentials based on the challenge
	 := .newCredentials(, )
	,  := .authorize()
	if  != nil {
		return , 
	}
	// A body is closed before the second round trip; presumably this is the
	// 401 response's body — confirm once identifiers are restored.
	 = .Body.Close()
	if  != nil {
		return nil, 
	}

	// Make authenticated request
	.Header.Set(hdrAuthorizationKey, )
	return .transport.RoundTrip()
}

// NOTE(review): the method name, receiver name and parameter names are
// missing from this copy of the file; restore them before building. The
// signature matches the newCredentials call made elsewhere in this file, so
// that is presumably its name — confirm.
//
// It returns a new credentials value populated from username, password,
// userhash, realm, nonce, algorithm, opaque, qop, the request URI and the
// HTTP method. Presumably the first two come from the transport and the
// challenge fields from the *challenge argument — the selectors' operands are
// not visible. sessionAlg is derived from the "-sess" suffix of the algorithm
// and the nonce counter starts at zero.
func ( *digestTransport) ( *http.Request,  *challenge) *credentials {
	return &credentials{
		username:   .username,
		userhash:   .userhash,
		realm:      .realm,
		nonce:      .nonce,
		// RequestURI yields the path plus query, which is what gets hashed
		// and sent as the digest "uri" parameter.
		digestURI:  .URL.RequestURI(),
		algorithm:  .algorithm,
		sessionAlg: strings.HasSuffix(.algorithm, "-sess"),
		opaque:     .opaque,
		messageQop: .qop,
		nc:         0,
		method:     .Method,
		password:   .password,
	}
}

// challenge holds the parameters extracted from a "Digest" challenge by
// parseChallenge. Every field is kept as a string.
type challenge struct {
	// Stored after trimming by parseChallenge.
	realm, domain, nonce, opaque string
	// Stored exactly as they appear after the "=" sign.
	stale, algorithm string
	// Stored after trimming by parseChallenge.
	qop, userhash string
}

// parseChallenge turns a digest challenge string into a challenge value.
//
// After trimming, the input must begin with "Digest "; the remainder is split
// on ", " and each piece is split on its first "=" into a name and a value.
// Names are matched exactly (case-sensitively) against the known parameters.
//
// It returns ErrDigestBadChallenge when the "Digest " prefix is missing, when
// a piece has no "=", or when a parameter name is not recognised, and
// ErrDigestCharset when a charset other than UTF-8 (compared after
// upper-casing) is supplied.
//
// NOTE(review): the parameter name and all local identifiers are missing from
// this copy of the file; restore them before building. The two constants are
// presumably the whitespace and double-quote cut sets passed to strings.Trim,
// but which one each call uses cannot be told from here — confirm.
//
// NOTE(review): because the list is split on the two-character string ", ",
// parameters separated by a bare "," and quoted values that themselves
// contain ", " are not split as intended.
func parseChallenge( string) (*challenge, error) {
	const  = " \n\r\t"
	const  = `"`
	 := strings.Trim(, )
	if !strings.HasPrefix(, "Digest ") {
		return nil, ErrDigestBadChallenge
	}
	// 7 == len("Digest "): drop the scheme and keep only the parameter list.
	 = strings.Trim([7:], )
	 := strings.Split(, ", ")
	 := &challenge{}
	var  []string
	for  := range  {
		// SplitN with a limit of 2 keeps any further "=" inside the value.
		 = strings.SplitN([], "=", 2)
		if len() != 2 {
			return nil, ErrDigestBadChallenge
		}
		switch [0] {
		case "realm":
			.realm = strings.Trim([1], )
		case "domain":
			.domain = strings.Trim([1], )
		case "nonce":
			.nonce = strings.Trim([1], )
		case "opaque":
			.opaque = strings.Trim([1], )
		case "stale":
			// Stored untrimmed.
			.stale = [1]
		case "algorithm":
			// Stored untrimmed.
			.algorithm = [1]
		case "qop":
			.qop = strings.Trim([1], )
		case "charset":
			// The charset is only validated; it is not stored.
			if strings.ToUpper(strings.Trim([1], )) != "UTF-8" {
				return nil, ErrDigestCharset
			}
		case "userhash":
			.userhash = strings.Trim([1], )
		default:
			return nil, ErrDigestBadChallenge
		}
	}
	return , nil
}

// credentials carries everything needed to compute and format one digest
// Authorization header value: the challenge parameters, the user's secret,
// the request line details and the per-request client state.
type credentials struct {
	// Identity and challenge parameters.
	username, userhash, realm, nonce, digestURI, algorithm string
	// sessionAlg selects the "-sess" form of the first hash.
	sessionAlg bool
	// cNonce is the client nonce; opaque and messageQop are echoed back in
	// the header when non-empty.
	cNonce, opaque, messageQop string
	// nc is the nonce counter, formatted as eight hex digits.
	nc int
	// Request method and the user's password.
	method, password string
}

// NOTE(review): the method name, receiver name and all local identifiers are
// missing from this copy of the file; restore them before building. Since the
// body calls validateQop and resp and returns (string, error), this is
// presumably the authorize method called from the transport — confirm.
//
// It builds the complete "Digest ..." authorization value. It returns
// ErrDigestAlgNotSupported when the algorithm has no entry in hashFuncs, and
// otherwise passes on any error from validateQop or resp.
//
// NOTE(review): hashFuncs accepts the empty algorithm, and the algorithm
// parameter is appended unconditionally, so an unspecified algorithm is
// emitted as a bare "algorithm=".
func ( *credentials) () (string, error) {
	if ,  := hashFuncs[.algorithm]; ! {
		return "", ErrDigestAlgNotSupported
	}

	if  := .validateQop();  != nil {
		return "", 
	}

	// The response digest is computed here, before the username may be
	// replaced by its hash below, so it is derived from the plain username.
	,  := .resp()
	if  != nil {
		return "", 
	}

	 := make([]string, 0, 10)
	if .userhash == "true" {
		// RFC 7616 3.4.4
		.username = .h(fmt.Sprintf("%s:%s", .username, .realm))
		 = append(, fmt.Sprintf(`userhash=%s`, .userhash))
	}
	 = append(, fmt.Sprintf(`username="%s"`, .username))
	 = append(, fmt.Sprintf(`realm="%s"`, .realm))
	 = append(, fmt.Sprintf(`nonce="%s"`, .nonce))
	 = append(, fmt.Sprintf(`uri="%s"`, .digestURI))
	 = append(, fmt.Sprintf(`response="%s"`, ))
	 = append(, fmt.Sprintf(`algorithm=%s`, .algorithm))
	if .opaque != "" {
		 = append(, fmt.Sprintf(`opaque="%s"`, .opaque))
	}
	// qop, nc and cnonce are emitted together, and only when a qop is set.
	if .messageQop != "" {
		 = append(, fmt.Sprintf("qop=%s", .messageQop))
		 = append(, fmt.Sprintf("nc=%08x", .nc))
		 = append(, fmt.Sprintf(`cnonce="%s"`, .cNonce))
	}

	return fmt.Sprintf("Digest %s", strings.Join(, ", ")), nil
}

// NOTE(review): the method name, receiver name and all local identifiers are
// missing from this copy of the file; restore them before building. This is
// presumably the validateQop method called while building the authorization
// value — confirm.
//
// It returns ErrDigestNoQop when messageQop is empty and
// ErrDigestQopNotSupported when the ", "-separated qop list contains no
// "auth" entry. On success messageQop is overwritten with exactly "auth".
//
// NOTE(review): the list is split on ", " (comma followed by a space), so a
// list written as "auth,auth-int" is treated as one unsupported entry.
func ( *credentials) () error {
	// Only the "auth" quality of protection is supported for now.
	// TODO: add auth-int support.
	// NOTE: cURL supports the auth-int qop for requests other than POST and PUT
	// (i.e. without a body) by hashing an empty string. Is this applicable for
	// resty? See: https://github.com/curl/curl/blob/307b7543ea1e73ab04e062bdbe4b5bb409eaba3a/lib/vauth/digest.c#L774
	if .messageQop == "" {
		return ErrDigestNoQop
	}
	 := strings.Split(.messageQop, ", ")
	var  bool
	for ,  := range  {
		if  == "auth" {
			 = true
			break
		}
	}
	if ! {
		return ErrDigestQopNotSupported
	}

	// Collapse the offered list to the single value that is actually used.
	.messageQop = "auth"

	return nil
}

// NOTE(review): the method name, receiver name, parameter name and local
// identifiers are missing from this copy of the file; restore them before
// building. This is presumably the h helper used by the other digest
// methods — confirm.
//
// It looks up a hash constructor in hashFuncs by the algorithm field, feeds
// the given string through a new hash and returns the digest as lowercase
// hex.
//
// NOTE(review): the map lookup result is not checked here, so an algorithm
// missing from hashFuncs yields a nil constructor and a panic on call.
func ( *credentials) ( string) string {
	 := hashFuncs[.algorithm]
	 := ()
	_, _ = .Write([]byte()) // Hash.Write never returns an error
	return fmt.Sprintf("%x", .Sum(nil))
}

// NOTE(review): the method name, receiver name and all local identifiers are
// missing from this copy of the file; restore them before building. Since the
// body calls ha1, ha2 and kd, this is presumably the resp method — confirm.
//
// It computes the digest "response" value. Each call increments the nonce
// counter and stores a fresh client nonce in cNonce, so it is not idempotent.
// An error from reading rand.Reader is returned unchanged.
func ( *credentials) () (string, error) {
	.nc++

	// 16 random bytes format to 32 hex characters, so the [:32] slice below
	// keeps the whole string.
	 := make([]byte, 16)
	,  := io.ReadFull(rand.Reader, )
	if  != nil {
		return "", 
	}
	.cNonce = fmt.Sprintf("%x", )[:32]

	 := .ha1()
	 := .ha2()

	// Presumably kd(HA1, "nonce:nc:cnonce:qop:HA2"); the two stripped
	// arguments cannot be told apart from here — confirm.
	return .kd(, fmt.Sprintf("%s:%08x:%s:%s:%s",
		.nonce, .nc, .cNonce, .messageQop, )), nil
}

// NOTE(review): the method name, receiver name and parameter names are
// missing from this copy of the file; restore them before building. This is
// presumably the kd helper called when computing the response — confirm.
//
// It hashes two strings joined by a single colon using h.
func ( *credentials) (,  string) string {
	return .h(fmt.Sprintf("%s:%s", , ))
}

// NOTE(review): the method name, receiver name and the local identifier are
// missing from this copy of the file; restore them before building. This is
// presumably the ha1 method — confirm.
//
// It hashes "username:realm:password". When sessionAlg is set, the result is
// the hash of three colon-joined parts ending in nonce and cNonce; the first
// part is presumably the base hash (its operand is stripped).
//
// RFC 7616 3.4.2
func ( *credentials) () string {
	 := .h(fmt.Sprintf("%s:%s:%s", .username, .realm, .password))
	if .sessionAlg {
		return .h(fmt.Sprintf("%s:%s:%s", , .nonce, .cNonce))
	}

	return 
}

// NOTE(review): the method name and receiver name are missing from this copy
// of the file; restore them before building. This is presumably the ha2
// method — confirm.
//
// It hashes "method:digestURI". The request body is not included, which
// matches the "auth" quality of protection only.
//
// RFC 7616 3.4.3
func ( *credentials) () string {
	// currently no auth-int support
	return .h(fmt.Sprintf("%s:%s", .method, .digestURI))
}