package client
import (
"encoding/json"
"fmt"
"sort"
"strings"
"sync"
"time"
"github.com/jcmturner/gokrb5/v8/iana/nametype"
"github.com/jcmturner/gokrb5/v8/krberror"
"github.com/jcmturner/gokrb5/v8/messages"
"github.com/jcmturner/gokrb5/v8/types"
)
// sessions is the client's cache of TGT sessions, keyed by realm name
// (see update, which stores each session under sess.realm).
// Individual sessions carry their own lock in addition to mux.
type sessions struct {
// Entries maps a realm name to the TGT session for that realm.
Entries map [string ]*session
// mux protects concurrent access to Entries.
mux sync .RWMutex
}
// destroy tears down every cached session and empties the cache.
//
// Each session is destroyed (which signals its auto-renewal goroutine, if
// any, and expires its times) and then removed from Entries. Deleting map
// entries while ranging over the map is permitted by the Go spec.
func (s *sessions) destroy() {
	s.mux.Lock()
	defer s.mux.Unlock()
	for realm := range s.Entries {
		s.Entries[realm].destroy()
		delete(s.Entries, realm)
	}
}
// update stores sess as the cached session for its realm.
//
// If a different session is already cached for that realm, its auto-renewal
// goroutine is signalled to stop before the entry is replaced. Storing the
// same session pointer again is a no-op apart from the map write.
func (s *sessions) update(sess *session) {
	s.mux.Lock()
	defer s.mux.Unlock()
	if old, ok := s.Entries[sess.realm]; ok && old != sess {
		old.mux.Lock()
		if old.cancel != nil {
			// Non-blocking send. The cancel channel created by
			// enableAutoSessionRenewal has capacity 1. If a cancel is already
			// pending, or the renewal goroutine has exited and nobody will
			// drain it, a blocking send here would hang forever while holding
			// both s.mux and old.mux. One pending signal is sufficient to
			// stop the goroutine.
			select {
			case old.cancel <- true:
			default:
			}
		}
		old.mux.Unlock()
	}
	s.Entries[sess.realm] = sess
}
// get looks up the cached session for realm under a read lock.
// The boolean reports whether a session was present.
func (s *sessions) get(realm string) (*session, bool) {
	s.mux.RLock()
	found, present := s.Entries[realm]
	s.mux.RUnlock()
	return found, present
}
// session is one cached TGT together with the session key and time fields
// taken from the KDC reply that issued (or last renewed) it.
type session struct {
// realm is the realm this TGT is for; addSession takes it from the last
// component of the krbtgt service principal name.
realm string
// authTime, endTime and renewTill are copied from the decrypted KDC reply
// part (see addSession / update).
authTime time .Time
endTime time .Time
renewTill time .Time
// tgt is the ticket granting ticket itself.
tgt messages .Ticket
// sessionKey is the key shared with the KDC for use with tgt.
sessionKey types .EncryptionKey
sessionKeyExpiration time .Time
// cancel is used to stop the auto-renewal goroutine. It is nil until
// enableAutoSessionRenewal has run for this session.
cancel chan bool
// mux guards the mutable fields above.
mux sync .RWMutex
}
// jsonSession is the serialised view of a session's timing information
// produced by sessions.JSON. The struct has no json tags, so encoding/json
// uses the Go field names as keys, emitted in declaration order. Reordering
// these fields changes the JSON output.
type jsonSession struct {
Realm string
AuthTime time .Time
EndTime time .Time
RenewTill time .Time
SessionKeyExpiration time .Time
}
// addSession records a TGT obtained from the KDC as the cached session for
// the ticket's realm and starts automatic renewal for it.
//
// tgt is the ticket; dep is the decrypted encrypted part of the KDC reply
// that carried it, which supplies the session key and lifetimes. Tickets
// whose service principal name is empty or whose first component is not
// "krbtgt" (case-insensitive) are ignored.
func (cl *Client) addSession(tgt messages.Ticket, dep messages.EncKDCRepPart) {
	names := tgt.SName.NameString
	// An empty service principal name previously caused an index-out-of-range
	// panic on names[0]; treat it like any other non-krbtgt ticket.
	if len(names) == 0 || !strings.EqualFold(names[0], "krbtgt") {
		return
	}
	realm := names[len(names)-1]
	s := &session{
		realm:                realm,
		authTime:             dep.AuthTime,
		endTime:              dep.EndTime,
		renewTill:            dep.RenewTill,
		tgt:                  tgt,
		sessionKey:           dep.Key,
		sessionKeyExpiration: dep.KeyExpiration,
	}
	// Arm renewal (which creates s.cancel) before publishing s in the cache.
	// Otherwise a concurrent update that replaces s in the window between the
	// two calls finds s.cancel == nil, cannot stop it, and the renewal
	// goroutine started afterwards would keep refreshing an orphaned session.
	cl.enableAutoSessionRenewal(s)
	cl.sessions.update(s)
	cl.Log("TGT session added for %s (EndTime: %v)", realm, dep.EndTime)
}
// update overwrites the session's ticket, key and time fields in place with
// those from a newly issued TGT and its decrypted KDC reply part. The realm
// and cancel channel are left untouched.
func (s *session) update(tgt messages.Ticket, dep messages.EncKDCRepPart) {
	s.mux.Lock()
	defer s.mux.Unlock()
	s.tgt, s.sessionKey = tgt, dep.Key
	s.authTime, s.endTime, s.renewTill = dep.AuthTime, dep.EndTime, dep.RenewTill
	s.sessionKeyExpiration = dep.KeyExpiration
}
// destroy stops the session's auto-renewal goroutine (if one was armed) and
// marks the session as expired by setting endTime, renewTill and
// sessionKeyExpiration to the current UTC time.
func (s *session) destroy() {
	s.mux.Lock()
	defer s.mux.Unlock()
	if s.cancel != nil {
		// Non-blocking send. The renewal goroutine takes s.mux.RLock at the
		// top of each loop iteration, so blocking here while holding s.mux,
		// with the buffer already full or the goroutine gone, would deadlock.
		// A single pending signal is enough to stop the goroutine.
		select {
		case s.cancel <- true:
		default:
		}
	}
	now := time.Now().UTC()
	s.endTime = now
	s.renewTill = now
	s.sessionKeyExpiration = now
}
// valid reports whether the current UTC time lies strictly between the
// session's authTime and endTime.
func (s *session) valid() bool {
	s.mux.RLock()
	defer s.mux.RUnlock()
	now := time.Now().UTC()
	return s.authTime.Before(now) && now.Before(s.endTime)
}
// tgtDetails returns a consistent snapshot of the session's realm, TGT and
// session key, read under the session's read lock.
func (s *session) tgtDetails() (realm string, ticket messages.Ticket, key types.EncryptionKey) {
	s.mux.RLock()
	realm, ticket, key = s.realm, s.tgt, s.sessionKey
	s.mux.RUnlock()
	return realm, ticket, key
}
// timeDetails returns a consistent snapshot of the session's realm followed
// by its authTime, endTime, renewTill and sessionKeyExpiration, read under
// the session's read lock.
func (s *session) timeDetails() (realm string, auth, end, renew, keyExp time.Time) {
	s.mux.RLock()
	realm = s.realm
	auth, end, renew, keyExp = s.authTime, s.endTime, s.renewTill, s.sessionKeyExpiration
	s.mux.RUnlock()
	return realm, auth, end, renew, keyExp
}
// JSON renders the timing information of every cached session as an indented
// JSON array, ordered by realm name.
//
// An empty cache renders as "[]". Previously a nil slice was marshalled, which
// encoding/json emits as "null". The returned error is whatever
// json.MarshalIndent reports.
func (s *sessions) JSON() (string, error) {
	s.mux.RLock()
	defer s.mux.RUnlock()
	realms := make([]string, 0, len(s.Entries))
	for k := range s.Entries {
		realms = append(realms, k)
	}
	// Map iteration order is random; sort for deterministic output.
	sort.Strings(realms)
	js := make([]jsonSession, 0, len(realms))
	for _, k := range realms {
		r, at, et, rt, kt := s.Entries[k].timeDetails()
		js = append(js, jsonSession{
			Realm:                r,
			AuthTime:             at,
			EndTime:              et,
			RenewTill:            rt,
			SessionKeyExpiration: kt,
		})
	}
	b, err := json.MarshalIndent(js, "", " ")
	if err != nil {
		return "", err
	}
	return string(b), nil
}
// enableAutoSessionRenewal creates the session's cancel channel and starts a
// background goroutine that refreshes the TGT after 5/6 of its remaining
// lifetime has elapsed, repeating after each successful renewal.
//
// The goroutine exits when any of the following happens:
//   - the session's endTime has already passed;
//   - refreshSession reports it did not renew and returned no error (the
//     realmLogin fallback path);
//   - a value is received on the cancel channel.
//
// Refresh errors are logged and the loop retries on the next (shorter) wait.
func (cl *Client) enableAutoSessionRenewal(s *session) {
	// Buffered so a canceller never has to rendezvous with the goroutine,
	// which may be busy inside refreshSession when the signal is sent.
	cancel := make(chan bool, 1)
	s.mux.Lock()
	s.cancel = cancel
	s.mux.Unlock()
	go func() {
		for {
			s.mux.RLock()
			w := (s.endTime.Sub(time.Now().UTC()) * 5) / 6
			s.mux.RUnlock()
			if w < 0 {
				return
			}
			// The timer is scoped to this iteration. The goroutine selects on
			// the locally captured channel so that it never reads s.cancel
			// without holding s.mux.
			timer := time.NewTimer(w)
			select {
			case <-timer.C:
				renewal, err := cl.refreshSession(s)
				if err != nil {
					cl.Log("error refreshing session: %v", err)
				}
				if !renewal && err == nil {
					return
				}
			case <-cancel:
				timer.Stop()
				return
			}
		}
	}()
}
// renewTGT sends a TGS request for krbtgt/<realm> using the session's current
// TGT and session key, then stores the returned ticket and reply part back
// into the session and the client's session cache.
//
// The final true argument to TGSREQGenerateAndExchange presumably requests
// renewal (TODO confirm against that function). Errors are wrapped with
// krberror.KRBMsgError.
func (cl *Client) renewTGT(s *session) error {
	realm, tgt, skey := s.tgtDetails()
	krbtgt := types.PrincipalName{
		NameType:   nametype.KRB_NT_SRV_INST,
		NameString: []string{"krbtgt", realm},
	}
	_, rep, err := cl.TGSREQGenerateAndExchange(krbtgt, cl.Credentials.Domain(), tgt, skey, true)
	if err != nil {
		return krberror.Errorf(err, krberror.KRBMsgError, "error renewing TGT for %s", realm)
	}
	encPart := rep.DecryptedEncPart
	s.update(rep.Ticket, encPart)
	cl.sessions.update(s)
	cl.Log("TGT session renewed for %s (EndTime: %v)", realm, encPart.EndTime)
	return nil
}
// refreshSession obtains a fresh TGT for the session's realm.
//
// While the current UTC time is before the session's renewTill, the existing
// TGT is renewed via renewTGT and the boolean result is true. Otherwise a
// new login to the realm is performed via realmLogin and the result is false.
// The error from whichever path ran is returned alongside.
func (cl *Client) refreshSession(s *session) (bool, error) {
	s.mux.RLock()
	realm, renewTill := s.realm, s.renewTill
	s.mux.RUnlock()
	cl.Log("refreshing TGT session for %s", realm)
	if !time.Now().UTC().Before(renewTill) {
		return false, cl.realmLogin(realm)
	}
	return true, cl.renewTGT(s)
}
// ensureValidSession makes sure a usable TGT session exists for realm.
//
// If no session is cached, it logs in to the realm. If one is cached and more
// than 1/6 of its total lifetime (endTime - authTime) remains, nothing is
// done. Otherwise the session is refreshed via refreshSession.
func (cl *Client) ensureValidSession(realm string) error {
	s, ok := cl.sessions.get(realm)
	if !ok {
		return cl.realmLogin(realm)
	}
	s.mux.RLock()
	threshold := s.endTime.Sub(s.authTime) / 6
	remaining := s.endTime.Sub(time.Now().UTC())
	s.mux.RUnlock()
	if remaining > threshold {
		return nil
	}
	_, err := cl.refreshSession(s)
	return err
}
// sessionTGT returns the TGT and session key for realm, first logging in or
// refreshing as needed via ensureValidSession.
//
// On error the ticket and key are zero values.
func (cl *Client) sessionTGT(realm string) (tgt messages.Ticket, sessionKey types.EncryptionKey, err error) {
	if err = cl.ensureValidSession(realm); err != nil {
		return tgt, sessionKey, err
	}
	s, ok := cl.sessions.get(realm)
	if !ok {
		return tgt, sessionKey, fmt.Errorf("could not find TGT session for %s", realm)
	}
	_, tgt, sessionKey = s.tgtDetails()
	return tgt, sessionKey, nil
}
// sessionTimes reports the auth, end, renew-till and session-key expiry times
// of the cached TGT session for realm.
//
// Unlike sessionTGT it neither logs in nor refreshes. It returns an error,
// with zero times, when no session is cached for realm.
func (cl *Client) sessionTimes(realm string) (authTime, endTime, renewTime, sessionExp time.Time, err error) {
	s, ok := cl.sessions.get(realm)
	if ok {
		_, authTime, endTime, renewTime, sessionExp = s.timeDetails()
		return authTime, endTime, renewTime, sessionExp, nil
	}
	err = fmt.Errorf("could not find TGT session for %s", realm)
	return authTime, endTime, renewTime, sessionExp, err
}
// spnRealm maps a service principal name to a realm by passing its last name
// component to the client configuration's ResolveRealm.
//
// NOTE(review): this indexes NameString unconditionally and so panics on an
// empty principal name. Callers presumably never pass one; confirm.
func (cl *Client) spnRealm(spn types.PrincipalName) string {
	host := spn.NameString[len(spn.NameString)-1]
	return cl.Config.ResolveRealm(host)
}
These pages were generated with Golds v0.6.7 (GOOS=linux GOARCH=amd64).
Golds is a Go 101 project developed by Tapir Liu.
PRs and bug reports are welcome and can be submitted to the issue list.
Please follow @Go100and1 (reachable from the QR code on the left) to get the latest news about Golds.