package smb2

import (
	
	
	
	
	
	
	
	

	. 
	. 
)

// Negotiator contains options for func (*Dialer) Dial.
//
// The options shape the SMB2 NEGOTIATE request sent when a connection is
// set up, and the server's response is checked against them.
type Negotiator struct {
	// RequireMessageSigning enforces signing. The request advertises
	// SMB2_NEGOTIATE_SIGNING_REQUIRED instead of
	// SMB2_NEGOTIATE_SIGNING_ENABLED. The resulting connection then requires
	// signed responses even if the server itself does not demand signing.
	RequireMessageSigning bool

	// ClientGuid is the GUID sent to the server. If it's zero, a random one
	// is generated by crypto/rand.
	ClientGuid [16]byte

	// SpecifiedDialect pins the single dialect to offer. If it's zero,
	// clientDialects is used instead. (See feature.go for more details.)
	//
	// NOTE: when the server answers with dialect revision SMB2, negotiation
	// overwrites this field with SMB210 and retries. The Negotiator it is
	// called on is therefore modified.
	SpecifiedDialect uint16
}

// makeRequest builds the SMB2 NEGOTIATE request described by the Negotiator.
//
// The request is filled in as follows:
//   - SecurityMode follows RequireMessageSigning.
//   - Capabilities is clientCapabilities.
//   - The client GUID is copied from the Negotiator, or filled from
//     crypto/rand when the Negotiator's ClientGuid is zero.
//
// If SpecifiedDialect is set, only that dialect is offered. In that case SMB
// 3.1.1 negotiate contexts are attached only when the dialect is SMB311.
// Otherwise clientDialects is offered and the contexts are always attached.
// The contexts are:
//   - a HashContext with clientHashAlgorithms and a fresh random 32-byte salt;
//   - a CipherContext with clientCiphers.
//
// It returns an *InternalError in two cases: crypto/rand fails, or
// SpecifiedDialect is not one of SMB202, SMB210, SMB300, SMB302, SMB311.
// CreditCharge is not set here; the caller sets it.
func ( *Negotiator) () (*NegotiateRequest, error) {
	 := new(NegotiateRequest)

	if .RequireMessageSigning {
		.SecurityMode = SMB2_NEGOTIATE_SIGNING_REQUIRED
	} else {
		.SecurityMode = SMB2_NEGOTIATE_SIGNING_ENABLED
	}

	.Capabilities = clientCapabilities

	// `zero` is presumably a package-level all-zero GUID value; confirm.
	if .ClientGuid == zero {
		,  := rand.Read(.ClientGuid[:])
		if  != nil {
			return nil, &InternalError{.Error()}
		}
	} else {
		.ClientGuid = .ClientGuid
	}

	if .SpecifiedDialect != UnknownSMB {
		.Dialects = []uint16{.SpecifiedDialect}

		// Go cases do not fall through. The empty cases below only keep the
		// pre-3.1.1 dialects out of the default (unsupported) branch; those
		// dialects get no negotiate contexts.
		switch .SpecifiedDialect {
		case SMB202:
		case SMB210:
		case SMB300:
		case SMB302:
		case SMB311:
			 := &HashContext{
				HashAlgorithms: clientHashAlgorithms,
				HashSalt:       make([]byte, 32),
			}
			if ,  := rand.Read(.HashSalt);  != nil {
				return nil, &InternalError{.Error()}
			}

			 := &CipherContext{
				Ciphers: clientCiphers,
			}

			.Contexts = append(.Contexts, , )
		default:
			return nil, &InternalError{"unsupported dialect specified"}
		}
	} else {
		.Dialects = clientDialects

		// The full dialect list may include SMB311, so the 3.1.1 contexts
		// are always attached here.
		// NOTE(review): this duplicates the SMB311 branch above; a small
		// helper building both contexts would remove the repetition.
		 := &HashContext{
			HashAlgorithms: clientHashAlgorithms,
			HashSalt:       make([]byte, 32),
		}
		if ,  := rand.Read(.HashSalt);  != nil {
			return nil, &InternalError{.Error()}
		}

		 := &CipherContext{
			Ciphers: clientCiphers,
		}

		.Contexts = append(.Contexts, , )
	}

	return , nil
}

// Negotiates an SMB2 connection over the given transport.
//
// It wraps the transport and the credit account in a new conn and starts
// the conn's sender and receiver goroutines. It then sends a NEGOTIATE
// request built by makeRequest, with CreditCharge 1. The conn is configured
// from the server's response:
//   - signing requirement;
//   - capabilities shared with clientCapabilities;
//   - dialect;
//   - maximum transact, read and write sizes.
//
// If the server answers with dialect revision SMB2, SpecifiedDialect is set
// to SMB210 on the Negotiator and the exchange is retried. A response whose
// dialect differs from a pinned SpecifiedDialect is rejected.
//
// For SMB311 the negotiate contexts in the response are processed:
//   - The preauth-integrity hash (only SHA512 is accepted) is chained over
//     the request and response packets into conn.preauthIntegrityHashValue.
//   - The cipher (AES128CCM or AES128GCM) is recorded in conn.cipherId.
//   - Other context types are skipped.
//
// Errors from makeRequest, send, recv and accept are returned unchanged.
// Malformed or unexpected responses yield *InvalidResponseError.
//
// NOTE(review): the goroutines started below are not stopped on any of the
// error returns in this function. Nothing here closes the transport or
// signals rdone/wdone; confirm the caller tears them down on failure.
func ( *Negotiator) ( transport,  *account,  context.Context) (*conn, error) {
	 := &conn{
		t:                   ,
		outstandingRequests: newOutstandingRequests(),
		account:             ,
		rdone:               make(chan struct{}, 1),
		wdone:               make(chan struct{}, 1),
		write:               make(chan []byte, 1),
		werr:                make(chan error, 1),
	}

	go .runSender()
	go .runReciever()

	// Retry point, used when the server replies with dialect revision SMB2.
:
	,  := .makeRequest()
	if  != nil {
		return nil, 
	}

	// The NEGOTIATE request is charged a single credit.
	.CreditCharge = 1

	,  := .send(, )
	if  != nil {
		return nil, 
	}

	,  := .recv()
	if  != nil {
		return nil, 
	}

	,  := accept(SMB2_NEGOTIATE, )
	if  != nil {
		return nil, 
	}

	 := NegotiateResponseDecoder()
	if .IsInvalid() {
		return nil, &InvalidResponseError{"broken negotiate response format"}
	}

	// The retry pins SMB210, so the next request (built by makeRequest)
	// offers only that dialect.
	if .DialectRevision() == SMB2 {
		.SpecifiedDialect = SMB210

		goto 
	}

	if .SpecifiedDialect != UnknownSMB && .SpecifiedDialect != .DialectRevision() {
		return nil, &InvalidResponseError{"unexpected dialect returned"}
	}

	// Signing is required if either the client options or the server demand it.
	.requireSigning = .RequireMessageSigning || .SecurityMode()&SMB2_NEGOTIATE_SIGNING_REQUIRED != 0
	// Keep only the capabilities both sides support.
	.capabilities = clientCapabilities & .Capabilities()
	.dialect = .DialectRevision()
	.maxTransactSize = .MaxTransactSize()
	.maxReadSize = .MaxReadSize()
	.maxWriteSize = .MaxWriteSize()
	// NOTE(review): each NEGOTIATE sent advances sequenceWindow by its
	// CreditCharge (1) in makeRequestResponse. Without a retry the window is
	// already 1 here. After the SMB2 retry, message ids 0 and 1 are used and
	// the window is 2, so resetting it to 1 makes the next request reuse
	// message id 1. Confirm; leaving the window unchanged looks correct.
	.sequenceWindow = 1

	// conn.gssNegotiateToken = r.SecurityBuffer()
	// conn.clientGuid = n.ClientGuid
	// copy(conn.serverGuid[:], r.ServerGuid())

	if .dialect != SMB311 {
		return , nil
	}

	// Handle the negotiate contexts returned for SMB311.
	// NOTE(review): a response without a preauth-integrity context is not
	// rejected; preauthIntegrityHashId then stays zero.
	 := .NegotiateContextList()
	for  := .NegotiateContextCount();  > 0; -- {
		 := NegotiateContextDecoder()
		if .IsInvalid() {
			return nil, &InvalidResponseError{"broken negotiate context format"}
		}

		switch .ContextType() {
		case SMB2_PREAUTH_INTEGRITY_CAPABILITIES:
			 := HashContextDataDecoder(.Data())
			if .IsInvalid() {
				return nil, &InvalidResponseError{"broken hash context data format"}
			}

			 := .HashAlgorithms()

			// NOTE(review): an empty list is rejected too, but with a
			// misleading "multiple" message.
			if len() != 1 {
				return nil, &InvalidResponseError{"multiple hash algorithms"}
			}

			.preauthIntegrityHashId = [0]

			switch .preauthIntegrityHashId {
			case SHA512:
				// Chain the hash in two steps, each overwriting the stored
				// value in place:
				//   value = SHA512(value || request packet)
				//   value = SHA512(value || response packet)
				 := sha512.New()
				.Write(.preauthIntegrityHashValue[:])
				.Write(.pkt)
				.Sum(.preauthIntegrityHashValue[:0])

				.Reset()
				.Write(.preauthIntegrityHashValue[:])
				.Write()
				.Sum(.preauthIntegrityHashValue[:0])
			default:
				return nil, &InvalidResponseError{"unknown hash algorithm"}
			}
		case SMB2_ENCRYPTION_CAPABILITIES:
			 := CipherContextDataDecoder(.Data())
			if .IsInvalid() {
				return nil, &InvalidResponseError{"broken cipher context data format"}
			}

			 := .Ciphers()

			// NOTE(review): as above, an empty list gets the "multiple" message.
			if len() != 1 {
				return nil, &InvalidResponseError{"multiple cipher algorithms"}
			}

			.cipherId = [0]

			// Both supported ciphers are accepted as-is; anything else is rejected.
			switch .cipherId {
			case AES128CCM:
			case AES128GCM:
			default:
				return nil, &InvalidResponseError{"unknown cipher algorithm"}
			}
		default:
			// skip unsupported context
		}

		// Advance to the next context. If the list is shorter than the
		// advertised offset, it is emptied; any remaining count then
		// presumably fails to decode as an invalid context.
		 := .Next()

		if len() <  {
			 = nil
		} else {
			 = [:]
		}
	}

	return , nil
}

// requestResponse tracks one request from the time it is registered in
// outstandingRequests until its response (or an error) reaches the waiter.
type requestResponse struct {
	// msgId is the message id the request was sent with. It is also the key
	// in outstandingRequests.
	msgId uint64
	// asyncId is recorded from the interim STATUS_PENDING response, if any.
	asyncId uint64
	// creditRequest is the CreditRequestResponse value from the request
	// header. It is passed to account.charge together with the credits the
	// response grants.
	creditRequest uint16
	pkt           []byte // request packet
	// ctx bounds how long the sender waits for the write and for the response.
	ctx context.Context
	// recv receives the response packet; it has a one-slot buffer. On
	// failure, err is set and recv is closed instead.
	recv chan []byte
	// err is the error reported when recv is closed without a packet.
	err error
}

// outstandingRequests holds the sent requests that still await a response,
// keyed by message id.
//
// Every method locks m. This makes the set usable both from the sending side
// and from the receiver goroutine.
type outstandingRequests struct {
	// m guards requests.
	m sync.Mutex
	// requests maps requestResponse.msgId to the pending request.
	requests map[uint64]*requestResponse
}

// newOutstandingRequests returns an empty, ready-to-use outstandingRequests.
func newOutstandingRequests() *outstandingRequests {
	return &outstandingRequests{
		requests: make(map[uint64]*requestResponse, 0),
	}
}

// pop removes the request registered under the given message id and
// returns it. The bool result is false, and the request nil, when no
// request is registered under that id.
func ( *outstandingRequests) ( uint64) (*requestResponse, bool) {
	.m.Lock()
	defer .m.Unlock()

	,  := .requests[]
	if ! {
		return nil, false
	}

	delete(.requests, )

	return , true
}

// set registers the request under the given message id. Any request
// already registered under that id is replaced.
func ( *outstandingRequests) ( uint64,  *requestResponse) {
	.m.Lock()
	defer .m.Unlock()

	.requests[] = 
}

// shutdown fails every registered request with the given error. For each
// request, err is set and its recv channel is closed, which wakes any
// waiter.
//
// The entries are left in the map. A later pop of one of them only
// deletes it again.
// NOTE(review): calling shutdown a second time would close the same
// channels again and panic. The only call visible here is made once, when
// the receiver goroutine exits.
func ( *outstandingRequests) ( error) {
	.m.Lock()
	defer .m.Unlock()

	for ,  := range .requests {
		.err = 
		close(.recv)
	}
}

// conn is an SMB2 connection: a transport plus the state negotiated on it.
// It is driven by two goroutines:
//   - a sender that writes queued packets;
//   - a receiver that dispatches responses to outstanding requests.
type conn struct {
	// t is the underlying transport. The sender writes to it and the
	// receiver reads from it.
	t transport

	// session is the active session, if any. It supplies the session id and
	// performs packet encryption, signing, decryption and verification.
	// sequenceWindow is the next message id to assign; each request advances
	// it by its CreditCharge.
	// The remaining fields are filled from the NEGOTIATE exchange:
	//   - the dialect and the size limits;
	//   - whether signing is required;
	//   - the capabilities shared by both sides;
	//   - the SMB311 preauth-integrity hash id and its running value (64 bytes,
	//     the size of a SHA-512 digest);
	//   - the SMB311 cipher id.
	session                   *session
	outstandingRequests       *outstandingRequests
	sequenceWindow            uint64
	dialect                   uint16
	maxTransactSize           uint32
	maxReadSize               uint32
	maxWriteSize              uint32
	requireSigning            bool
	capabilities              uint32
	preauthIntegrityHashId    uint16
	preauthIntegrityHashValue [64]byte
	cipherId                  uint16

	// account does the credit accounting (loan, charge, opening).
	account *account

	// Goroutine plumbing:
	//   - write queues encoded packets for the sender.
	//   - werr carries the sender's result for each write.
	//   - wdone is closed by the receiver when it exits, which stops the
	//     sender.
	//   - rdone is checked by the receiver on exit to tell an intentional
	//     close from a failure. It is presumably signalled by the close path
	//     elsewhere in the package.
	rdone chan struct{}
	wdone chan struct{}
	write chan []byte
	werr  chan error

	// m is held for the whole send path, which covers message-id assignment
	// and the write hand-off. It also guards err.
	m sync.Mutex

	// err is set when the receiver exits. While it is non-nil, every send
	// fails with it.
	err error

	// gssNegotiateToken []byte
	// serverGuid        [16]byte
	// clientGuid        [16]byte

	_useSession int32 // accessed atomically; non-zero once the receiver should apply session handling
}

// useSession reports whether the session flag (_useSession) has been set.
// The receiver goroutine checks it for every incoming packet to decide
// whether to decrypt packets, check session and tree ids, and verify
// signatures.
func ( *conn) () bool {
	return atomic.LoadInt32(&._useSession) != 0
}

// Sets the session flag read by useSession. From then on the receiver
// goroutine handles incoming packets in session mode: decryption,
// session and tree id checks, and signature verification.
// The flag is set atomically because the receiver reads it concurrently.
func ( *conn) () {
	atomic.StoreInt32(&._useSession, 1)
}

// Returns a new timer that fires after a fixed 5 seconds.
// NOTE(review): the duration is hard-coded, and no caller is visible in
// this file; consider making it configurable.
func ( *conn) () *time.Timer {
	return time.NewTimer(5 * time.Second)
}

// Sends the request and waits for its response. It returns the response
// payload as checked by accept against the expected command (the uint16
// argument).
//
// Errors from send and recv are returned unchanged. Status handling,
// including error-to-os.Err* mapping, is done by accept.
func ( *conn) ( uint16,  Packet,  context.Context) ( []byte,  error) {
	,  := .send(, )
	if  != nil {
		return nil, 
	}

	,  := .recv()
	if  != nil {
		return nil, 
	}

	return accept(, )
}

// Reserves credits from the account for a request carrying a payload of
// the given size. It returns the credit charge and the payload size those
// credits cover.
//
// Without SMB2_GLOBAL_CAP_LARGE_MTU every request costs one credit. With
// it, the charge is one credit per started 64 KiB of payload; a size of 0
// still costs 1, because Go integer division truncates toward zero.
//
// account.loan presumably returns three values (confirm against account):
//   - the credits actually granted;
//   - whether the full request was granted, in which case the whole payload
//     size is returned;
//   - an error.
// When the grant is partial, the covered size is 64 KiB per granted credit.
//
// NOTE(review): the uint16 conversion silently truncates for payloads
// larger than 65535*64 KiB.
func ( *conn) ( int,  context.Context) ( uint16,  int,  error) {
	if .capabilities&SMB2_GLOBAL_CAP_LARGE_MTU == 0 {
		 = 1
	} else {
		 = uint16((-1)/(64*1024) + 1)
	}

	, ,  := .account.loan(, )
	if  != nil {
		return , 0, 
	}
	if  {
		return , , nil
	}

	return , 64 * 1024 * int(), nil
}

// Charges the given number of credits to the account, passing the same
// value as both arguments of account.charge.
//
// In the response-handling path, account.charge instead receives the
// response's granted CreditResponse and the request's creditRequest.
// NOTE(review): this presumably returns credits that were loaned but not
// used; confirm against account.charge.
func ( *conn) ( uint16) {
	.account.charge(, )
}

// send sends a request outside of any tree connection. It is sendWith with
// a nil *treeConn. The returned requestResponse is then awaited with recv.
func ( *conn) ( Packet,  context.Context) ( *requestResponse,  error) {
	return .sendWith(, nil, )
}

// sendWith sends a request, optionally within a tree connection (the
// *treeConn may be nil). It returns the request's requestResponse, on which
// the response can be awaited with recv.
//
// The conn mutex is held for the whole call. As a result, message ids are
// assigned (in makeRequestResponse) in the same order in which packets are
// handed to the sender goroutine.
//
// Failure modes:
//   - If the conn has already failed, its sticky err is returned.
//   - If the request fails after it was registered, it is unregistered again.
//   - A write error yields a *TransportError.
//   - A done ctx yields a *ContextError.
//
// NOTE(review): the ctx can be cancelled after the packet was queued (the
// inner select below). The sender still writes the packet and later stores
// its result in the one-slot werr channel. The next call then reads that
// stale result as its own write result, and later calls can keep reading
// the previous packet's result.
func ( *conn) ( Packet,  *treeConn,  context.Context) ( *requestResponse,  error) {
	.m.Lock()
	defer .m.Unlock()

	if .err != nil {
		return nil, .err
	}

	// Fail fast on an already-cancelled context, before consuming a message id.
	select {
	case <-.Done():
		return nil, &ContextError{Err: .Err()}
	default:
		// do nothing
	}

	,  = .makeRequestResponse(, , )
	if  != nil {
		return nil, 
	}

	// Hand the packet to the sender, then wait for its write result.
	select {
	case .write <- .pkt:
		select {
		case  = <-.werr:
			if  != nil {
				.outstandingRequests.pop(.msgId)

				return nil, &TransportError{}
			}
		case <-.Done():
			.outstandingRequests.pop(.msgId)

			return nil, &ContextError{Err: .Err()}
		}
	case <-.Done():
		.outstandingRequests.pop(.msgId)

		return nil, &ContextError{Err: .Err()}
	}

	return , nil
}

// makeRequestResponse prepares one request for sending. It stamps the
// request header, encodes the packet (encrypting or signing it when a
// session is active) and registers a requestResponse for it in
// outstandingRequests.
//
// Header handling for every request except CANCEL:
//   - The message id is taken from sequenceWindow, which then advances by
//     CreditCharge.
//   - CreditRequestResponse defaults to CreditCharge, then is increased by
//     account.opening(). That value is presumably extra credits to ask for;
//     confirm.
// CANCEL requests are stamped with message id 0 and leave sequenceWindow
// and the credit request untouched.
//
// Protection by request type:
//   - Session setup requests are never encrypted or signed here.
//   - Other requests on a session are encrypted when the session or the
//     tree's share requires encryption.
//   - Otherwise they are signed, unless the session is a guest or null
//     session.
// An encryption failure is returned as *InternalError.
//
// sendWith calls this with conn.m held, which serialises the
// sequenceWindow updates.
//
// NOTE(review): every CANCEL request is registered under key 0. A later
// cancel therefore replaces an earlier one, and the entry can collide with
// the initial NEGOTIATE request (also id 0) while that is still pending.
func ( *conn) ( Packet,  *treeConn,  context.Context) ( *requestResponse,  error) {
	 := .Header()

	var  uint64

	if ,  := .(*CancelRequest); ! {
		 = .sequenceWindow

		 := .CreditCharge

		.sequenceWindow += uint64()
		if .CreditRequestResponse == 0 {
			.CreditRequestResponse = 
		}

		.CreditRequestResponse += .account.opening()
	}

	.MessageId = 

	 := .session

	if  != nil {
		.SessionId = .sessionId

		if  != nil {
			.TreeId = .treeId
		}
	}

	 := make([]byte, .Size())

	.Encode()

	if  != nil {
		if ,  := .(*SessionSetupRequest); ! {
			if .sessionFlags&SMB2_SESSION_FLAG_ENCRYPT_DATA != 0 || ( != nil && .shareFlags&SMB2_SHAREFLAG_ENCRYPT_DATA != 0) {
				,  = .encrypt()
				if  != nil {
					return nil, &InternalError{.Error()}
				}
			} else {
				if .sessionFlags&(SMB2_SESSION_FLAG_IS_GUEST|SMB2_SESSION_FLAG_IS_NULL) == 0 {
					 = .sign()
				}
			}
		}
	}

	 = &requestResponse{
		msgId:         ,
		creditRequest: .CreditRequestResponse,
		pkt:           ,
		ctx:           ,
		recv:          make(chan []byte, 1),
	}

	.outstandingRequests.set(, )

	return , nil
}

// recv waits for the response to a request sent with send or sendWith.
// One of three things happens:
//   - The response packet is delivered and returned.
//   - The request's recv channel is closed instead (by shutdown, or by the
//     receiver after a failed verification). The error stored in the
//     request is returned.
//   - The request's ctx is done first. The request is unregistered and a
//     *ContextError is returned.
//
// Interim STATUS_PENDING responses are not delivered here. The receiver
// re-registers the request, so recv keeps waiting for the final response.
func ( *conn) ( *requestResponse) ([]byte, error) {
	select {
	case  := <-.recv:
		if .err != nil {
			return nil, .err
		}
		return , nil
	case <-.ctx.Done():
		.outstandingRequests.pop(.msgId)

		return nil, &ContextError{Err: .ctx.Err()}
	}
}

// runSender is the writer goroutine. It writes each packet queued on
// conn.write to the transport and reports the write error (nil on success)
// on conn.werr. It runs until conn.wdone is closed.
func ( *conn) () {
	for {
		select {
		case <-.wdone:
			return
		case  := <-.write:
			,  := .t.Write()

			.werr <- 
		}
	}
}

// runReciever is the reader goroutine (started as runReciever). For each
// packet read from the transport it does the following:
//   - In session mode (useSession), it tries to decrypt the packet. Packets
//     that fail to decrypt, or that carry an unknown session or tree id, are
//     logged and dropped.
//   - It splits compound responses at NextCommand.
//   - For each part, it verifies the signature (session mode only). It then
//     hands the part, with any verification error, to tryHandle. Anything
//     tryHandle rejects is logged.
//
// On a transport read error the loop stops. Unless rdone has been
// signalled, the error is logged and used for the cleanup:
//   - all outstanding requests are failed with it;
//   - it becomes the conn's sticky err.
// Finally wdone is closed to stop the sender.
//
// NOTE(review): when rdone has been signalled, the error is replaced with
// nil. This has two consequences:
//   - Waiters in recv see their channel closed with a nil err, so they get
//     (nil, nil).
//   - conn.err stays nil, so a later send can queue a packet the stopped
//     sender never writes.
// Confirm callers cope with both.
func ( *conn) () {
	var  error

	for {
		,  := .t.ReadSize()
		if  != nil {
			 = &TransportError{}

			goto 
		}

		 := make([]byte, )

		_,  = .t.Read()
		if  != nil {
			 = &TransportError{}

			goto 
		}

		 := .useSession()

		// Set to true when the packet arrived encrypted. tryVerify exempts
		// such packets from the signing-required check.
		var  bool

		if  {
			, ,  = .tryDecrypt()
			if  != nil {
				logger.Println("skip:", )

				continue
			}

			 := PacketCodec()
			if  := .session;  != nil {
				if .sessionId != .SessionId() {
					logger.Println("skip:", &InvalidResponseError{"unknown session id"})

					continue
				}

				// NOTE(review): the table is indexed by the packet's TreeId,
				// so this comparison can only fail if the table stores a tree
				// under a different id; confirm it is intentional. This map
				// is also read from the receiver goroutine without a lock.
				if ,  := .treeConnTables[.TreeId()];  {
					if .treeId != .TreeId() {
						logger.Println("skip:", &InvalidResponseError{"unknown tree id"})

						continue
					}
				}
			}
		}

		// Remainder of a compound response, or nil when this is the last part.
		var  []byte

		for {
			 := PacketCodec()

			// NOTE(review): NextCommand comes from the peer and is not
			// bounds-checked. An offset beyond the packet length makes the
			// slicing below panic and crash this goroutine. Outside session
			// mode the header is not validated at all before these
			// accessors are used.
			if  := .NextCommand();  != 0 {
				,  = [:], [:]
			} else {
				 = nil
			}

			if  {
				 = .tryVerify(, )
			}

			// Pass any verification error on so that the waiting request
			// fails with it.
			 = .tryHandle(, )
			if  != nil {
				logger.Println("skip:", )
			}

			if  == nil {
				break
			}

			 = 
		}
	}

:
	// A signalled rdone marks an intentional close, so no error is logged or kept.
	select {
	case <-.rdone:
		 = nil
	default:
		logger.Println("error:", )
	}

	.m.Lock()
	defer .m.Unlock()

	.outstandingRequests.shutdown()

	.err = 

	close(.wdone)
}

// accept checks that a response packet answers the expected command (the
// uint16 argument) and maps its NT status to a result:
//   - STATUS_SUCCESS returns the response payload (Data).
//   - Name collision, name/path not found, and access denied/cannot delete
//     map to os.ErrExist, os.ErrNotExist and os.ErrPermission.
//   - For SMB2_SESSION_SETUP, STATUS_MORE_PROCESSING_REQUIRED returns the
//     payload with a nil error, so the exchange can continue.
//   - STATUS_BUFFER_OVERFLOW for QUERY_INFO and READ, and
//     STATUS_NOTIFY_ENUM_DIR for CHANGE_NOTIFY, return a bare *ResponseError.
//   - For SMB2_IOCTL, STATUS_BUFFER_OVERFLOW returns the payload AND a
//     *ResponseError when the payload decodes as an IOCTL response, so
//     callers must inspect both.
//   - Anything else becomes the error built by acceptError.
//
// A command mismatch yields an *InvalidResponseError.
func accept( uint16,  []byte) ( []byte,  error) {
	 := PacketCodec()
	if  := .Command();  !=  {
		return nil, &InvalidResponseError{fmt.Sprintf("expected command: %v, got %v", , )}
	}

	 := NtStatus(.Status())

	// Statuses with the same meaning for every command.
	switch  {
	case STATUS_SUCCESS:
		return .Data(), nil
	case STATUS_OBJECT_NAME_COLLISION:
		return nil, os.ErrExist
	case STATUS_OBJECT_NAME_NOT_FOUND, STATUS_OBJECT_PATH_NOT_FOUND:
		return nil, os.ErrNotExist
	case STATUS_ACCESS_DENIED, STATUS_CANNOT_DELETE:
		return nil, os.ErrPermission
	}

	// Command-specific statuses.
	switch  {
	case SMB2_SESSION_SETUP:
		if  == STATUS_MORE_PROCESSING_REQUIRED {
			return .Data(), nil
		}
	case SMB2_QUERY_INFO:
		if  == STATUS_BUFFER_OVERFLOW {
			return nil, &ResponseError{Code: uint32()}
		}
	case SMB2_IOCTL:
		if  == STATUS_BUFFER_OVERFLOW {
			if !IoctlResponseDecoder(.Data()).IsInvalid() {
				return .Data(), &ResponseError{Code: uint32()}
			}
		}
	case SMB2_READ:
		if  == STATUS_BUFFER_OVERFLOW {
			return nil, &ResponseError{Code: uint32()}
		}
	case SMB2_CHANGE_NOTIFY:
		if  == STATUS_NOTIFY_ENUM_DIR {
			return nil, &ResponseError{Code: uint32()}
		}
	}

	return nil, acceptError(uint32(), .Data())
}

// acceptError builds the error for a failed response, given its NT status
// code and its error-response payload.
//
// When the error response carries error contexts, the data of each context
// is collected into the *ResponseError. Otherwise the ResponseError's data
// is an empty slice. A malformed payload or context list yields an
// *InvalidResponseError instead.
//
// NOTE(review): when ErrorContextCount is 0, the ErrorData blob is dropped
// entirely; callers that need it cannot get it. Also, this loop rejects a
// truncated context list, whereas the negotiate-context loop silently stops.
func acceptError( uint32,  []byte) error {
	 := ErrorResponseDecoder()
	if .IsInvalid() {
		return &InvalidResponseError{"broken error response format"}
	}

	 := .ErrorData()

	if  := .ErrorContextCount();  != 0 {
		 := make([][]byte, )
		for  := range  {
			 := ErrorContextResponseDecoder()
			if .IsInvalid() {
				return &InvalidResponseError{"broken error context response format"}
			}

			[] = .ErrorContextData()

			// Advance to the next context; a short buffer is a format error.
			 := .Next()

			if len() <  {
				return &InvalidResponseError{"broken error context response format"}
			}

			 = [:]
		}
		return &ResponseError{Code: , data: }
	}
	return &ResponseError{Code: , data: [][]byte{}}
}

// tryDecrypt returns the packet to process and whether it arrived encrypted.
//
// A packet that parses as a plain SMB2 packet is returned unchanged, with
// false. Any other packet must meet three conditions:
//   - it is a valid transform packet;
//   - it has the Encrypted flag;
//   - it is addressed to the current session.
// Its decrypted content is then returned with true.
//
// Failures are reported as *InvalidResponseError in the middle result. The
// unusual (data, error, bool) order is part of the call signature used by
// the receiver.
func ( *conn) ( []byte) ([]byte, error, bool) {
	 := PacketCodec()
	if .IsInvalid() {
		 := TransformCodec()
		if .IsInvalid() {
			return nil, &InvalidResponseError{"broken packet header format"}, false
		}

		if .Flags() != Encrypted {
			return nil, &InvalidResponseError{"encrypted flag is not on"}, false
		}

		if .session == nil || .session.sessionId != .SessionId() {
			return nil, &InvalidResponseError{"unknown session id returned"}, false
		}

		,  := .session.decrypt()
		if  != nil {
			return nil, &InvalidResponseError{.Error()}, false
		}

		return , nil, true
	}

	return , nil, false
}

// tryVerify applies the signing policy to one plain (possibly decrypted)
// response packet. The bool argument says whether the packet arrived
// encrypted.
//
// The checks are:
//   - Packets with message id 0xFFFFFFFFFFFFFFFF are not checked. That id
//     is presumably the unsolicited oplock/lease-break id; confirm.
//   - A signed packet must belong to the current session and pass
//     session.verify.
//   - An unsigned packet is rejected only when all of these hold: signing
//     is required, the packet did not arrive encrypted, and it belongs to
//     the current session, which is neither a guest nor a null session.
//     Unsigned packets for other session ids pass.
//
// Violations are returned as *InvalidResponseError.
func ( *conn) ( []byte,  bool) error {
	 := PacketCodec()

	 := .MessageId()

	if  != 0xFFFFFFFFFFFFFFFF {
		if .Flags()&SMB2_FLAGS_SIGNED != 0 {
			if .session == nil || .session.sessionId != .SessionId() {
				return &InvalidResponseError{"unknown session id returned"}
			} else {
				if !.session.verify() {
					return &InvalidResponseError{"unverified packet returned"}
				}
			}
		} else {
			if .requireSigning && ! {
				if .session != nil {
					if .session.sessionFlags&(SMB2_SESSION_FLAG_IS_GUEST|SMB2_SESSION_FLAG_IS_NULL) == 0 {
						if .session.sessionId == .SessionId() {
							return &InvalidResponseError{"signing required"}
						}
					}
				}
			}
		}
	}

	return nil
}

// tryHandle delivers one response packet to the request waiting for it,
// found by the packet's message id.
//
// Outcomes:
//   - If the error argument is non-nil (e.g. a verification failure), the
//     request fails with it: err is set and recv is closed.
//   - A STATUS_PENDING interim response records the async id and charges
//     the response's credits. It then re-registers the request, so the
//     final response can still be delivered.
//   - Otherwise the credits are charged and the packet is sent on the
//     request's recv channel. That channel has one slot and the request was
//     just popped, so this send does not block.
//
// It returns an error only when no request is registered under the message id.
//
// NOTE(review): on the error-argument path, the credits granted by the
// response are not charged to the account, unlike the other two paths.
// Confirm this is intended.
func ( *conn) ( []byte,  error) error {
	 := PacketCodec()

	 := .MessageId()

	,  := .outstandingRequests.pop()
	switch {
	case !:
		return &InvalidResponseError{"unknown message id returned"}
	case  != nil:
		.err = 

		close(.recv)
	case NtStatus(.Status()) == STATUS_PENDING:
		.asyncId = .AsyncId()
		.account.charge(.CreditResponse(), .creditRequest)
		.outstandingRequests.set(, )
	default:
		.account.charge(.CreditResponse(), .creditRequest)

		.recv <- 
	}

	return nil
}