Blob Blame History Raw
package adal

// Copyright 2017 Microsoft Corporation
//
//  Licensed under the Apache License, Version 2.0 (the "License");
//  you may not use this file except in compliance with the License.
//  You may obtain a copy of the License at
//
//      http://www.apache.org/licenses/LICENSE-2.0
//
//  Unless required by applicable law or agreed to in writing, software
//  distributed under the License is distributed on an "AS IS" BASIS,
//  WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
//  See the License for the specific language governing permissions and
//  limitations under the License.

import (
	"context"
	"crypto/rand"
	"crypto/rsa"
	"crypto/sha1"
	"crypto/x509"
	"encoding/base64"
	"encoding/json"
	"errors"
	"fmt"
	"io"
	"io/ioutil"
	"math"
	"net/http"
	"net/url"
	"os"
	"strings"
	"sync"
	"time"

	"github.com/Azure/go-autorest/autorest/date"
	"github.com/dgrijalva/jwt-go"
)

// defaultRefresh is the default refresh window: EnsureFresh refreshes a token once it will
// expire within this interval.
const defaultRefresh = 5 * time.Minute

// OAuth "grant_type" values sent to the token endpoint.
const (
	// OAuthGrantTypeDeviceCode is the "grant_type" identifier used in device flow
	OAuthGrantTypeDeviceCode = "device_code"
	// OAuthGrantTypeClientCredentials is the "grant_type" identifier used in credential flows
	OAuthGrantTypeClientCredentials = "client_credentials"
	// OAuthGrantTypeUserPass is the "grant_type" identifier used in username and password auth flows
	OAuthGrantTypeUserPass = "password"
	// OAuthGrantTypeRefreshToken is the "grant_type" identifier used in refresh token flows
	OAuthGrantTypeRefreshToken = "refresh_token"
	// OAuthGrantTypeAuthorizationCode is the "grant_type" identifier used in authorization code flows
	OAuthGrantTypeAuthorizationCode = "authorization_code"
)

// Managed Service Identity (MSI) settings for virtual machines.
const (
	// metadataHeader is the header required by MSI extension
	metadataHeader = "Metadata"
	// msiEndpoint is the well known endpoint for getting MSI authentication tokens
	msiEndpoint = "http://169.254.169.254/metadata/identity/oauth2/token"
	// msiAPIVersion is the API version to use for the MSI endpoint
	msiAPIVersion = "2018-02-01"
	// defaultMaxMSIRefreshAttempts is the default number of attempts to refresh an MSI authentication token
	defaultMaxMSIRefreshAttempts = 5
)

// MSI settings for App Service and Functions.
const (
	// asMSIEndpointEnv is the environment variable used to store the endpoint on App Service and Functions
	asMSIEndpointEnv = "MSI_ENDPOINT"
	// asMSISecretEnv is the environment variable used to store the request secret on App Service and Functions
	asMSISecretEnv = "MSI_SECRET"
	// appServiceAPIVersion is the API version to use for the App Service MSI endpoint
	appServiceAPIVersion = "2017-09-01"
)

// Interfaces describing token providers and refreshers.
type (
	// OAuthTokenProvider is an interface which should be implemented by an access token retriever
	OAuthTokenProvider interface {
		OAuthToken() string
	}

	// MultitenantOAuthTokenProvider provides tokens used for multi-tenant authorization:
	// one primary token plus any number of auxiliary tokens.
	MultitenantOAuthTokenProvider interface {
		PrimaryOAuthToken() string
		AuxiliaryOAuthTokens() []string
	}

	// TokenRefreshError is an interface used by errors returned during token refresh.
	// Response exposes the HTTP response associated with the failure, which may be nil.
	TokenRefreshError interface {
		error
		Response() *http.Response
	}

	// Refresher is an interface for token refresh functionality
	Refresher interface {
		Refresh() error
		RefreshExchange(resource string) error
		EnsureFresh() error
	}

	// RefresherWithContext is an interface for token refresh functionality that
	// accepts a context.Context on every operation.
	RefresherWithContext interface {
		RefreshWithContext(ctx context.Context) error
		RefreshExchangeWithContext(ctx context.Context, resource string) error
		EnsureFreshWithContext(ctx context.Context) error
	}
)

// Function types used to customize token refresh.
type (
	// TokenRefreshCallback is the type representing callbacks that will be called after
	// a successful token refresh. A non-nil error from a callback fails the refresh call.
	TokenRefreshCallback func(Token) error

	// TokenRefresh is a type representing a custom callback to refresh a token for
	// the given resource.
	TokenRefresh func(ctx context.Context, resource string) (*Token, error)
)

// Token encapsulates the access token used to authorize Azure requests.
// https://docs.microsoft.com/en-us/azure/active-directory/develop/v1-oauth2-client-creds-grant-flow#service-to-service-access-token-response
//
// The expiry-related fields are kept as json.Number so they round-trip unchanged;
// the field order below is also the order of keys in the JSON encoding.
type Token struct {
	AccessToken  string `json:"access_token"`
	RefreshToken string `json:"refresh_token"`

	ExpiresIn json.Number `json:"expires_in"`
	ExpiresOn json.Number `json:"expires_on"`
	NotBefore json.Number `json:"not_before"`

	Resource string `json:"resource"`
	Type     string `json:"token_type"`
}

// newToken returns an otherwise empty Token whose numeric fields hold "0"
// instead of the empty string.
func newToken() Token {
	const zero json.Number = "0"
	return Token{ExpiresIn: zero, ExpiresOn: zero, NotBefore: zero}
}

// IsZero reports whether every field of the token holds its zero value.
// A Token produced by newToken is therefore not zero (its numeric fields are "0").
func (t Token) IsZero() bool {
	var empty Token
	return t == empty
}

// Expires returns the time.Time when the Token expires, in UTC.
// ExpiresOn is parsed as a (possibly fractional) number of seconds and converted with
// date.NewUnixTimeFromSeconds. If it cannot be parsed, -3600 seconds is used instead,
// which presumably yields an instant long in the past so the token reads as expired.
func (t Token) Expires() time.Time {
	s, err := t.ExpiresOn.Float64()
	if err != nil {
		s = -3600
	}

	expiration := date.NewUnixTimeFromSeconds(s)

	return time.Time(expiration).UTC()
}

// IsExpired returns true if the Token is expired, false otherwise.
// It is equivalent to WillExpireIn(0).
func (t Token) IsExpired() bool {
	return t.WillExpireIn(0)
}

// WillExpireIn returns true if the Token expires at or before the instant that lies the
// passed time.Duration interval from now (that is, it will be expired within d), and false
// if it is still valid after that interval.
func (t Token) WillExpireIn(d time.Duration) bool {
	return !t.Expires().After(time.Now().Add(d))
}

// OAuthToken returns the current access token.
func (t *Token) OAuthToken() string {
	return t.AccessToken
}

// ServicePrincipalSecret is an interface that allows various secret mechanisms to fill the form
// that is submitted when acquiring an oAuth token.
//
// SetAuthenticationValues adds the credential-specific fields to values. Depending on the
// implementation these are client_secret, username/password, code/redirect_uri, or
// client_assertion. A non-nil error aborts the token refresh.
type ServicePrincipalSecret interface {
	SetAuthenticationValues(spt *ServicePrincipalToken, values *url.Values) error
}

// ServicePrincipalNoSecret represents a secret type that contains no secret,
// meaning it is not valid for fetching a fresh token. It is installed by
// NewServicePrincipalTokenFromManualToken.
type ServicePrincipalNoSecret struct{}

// SetAuthenticationValues is a method of the interface ServicePrincipalSecret.
// It never fills the form; it always reports that no secret material is available.
func (noSecret *ServicePrincipalNoSecret) SetAuthenticationValues(spt *ServicePrincipalToken, v *url.Values) error {
	const msg = "Manually created ServicePrincipalToken does not contain secret material to retrieve a new access token"
	return fmt.Errorf(msg)
}

// MarshalJSON implements the json.Marshaler interface.
// Only the "type" discriminator is written, since there is no secret to persist.
func (noSecret ServicePrincipalNoSecret) MarshalJSON() ([]byte, error) {
	return json.Marshal(struct {
		Type string `json:"type"`
	}{Type: "ServicePrincipalNoSecret"})
}

// ServicePrincipalTokenSecret implements ServicePrincipalSecret for client_secret type authorization.
type ServicePrincipalTokenSecret struct {
	ClientSecret string `json:"value"`
}

// SetAuthenticationValues is a method of the interface ServicePrincipalSecret.
// It writes the client secret into the "client_secret" form field, replacing any prior value.
func (tokenSecret *ServicePrincipalTokenSecret) SetAuthenticationValues(spt *ServicePrincipalToken, v *url.Values) error {
	secretValue := tokenSecret.ClientSecret
	v.Set("client_secret", secretValue)
	return nil
}

// MarshalJSON implements the json.Marshaler interface.
// The output carries a "type" discriminator plus the secret itself under "value".
func (tokenSecret ServicePrincipalTokenSecret) MarshalJSON() ([]byte, error) {
	payload := struct {
		Type  string `json:"type"`
		Value string `json:"value"`
	}{
		Type:  "ServicePrincipalTokenSecret",
		Value: tokenSecret.ClientSecret,
	}
	return json.Marshal(payload)
}

// ServicePrincipalCertificateSecret implements ServicePrincipalSecret for generic RSA cert auth with signed JWTs.
type ServicePrincipalCertificateSecret struct {
	// Certificate's raw bytes feed the x5t (SHA-1 thumbprint) and x5c JWT headers.
	Certificate *x509.Certificate
	// PrivateKey signs the client assertion using RS256.
	PrivateKey *rsa.PrivateKey
}

// SignJwt returns the JWT signed with the certificate's private key.
// The token's audience is the token endpoint, issuer and subject are the client ID,
// it carries a random jti, and it is valid from now for 24 hours.
func (secret *ServicePrincipalCertificateSecret) SignJwt(spt *ServicePrincipalToken) (string, error) {
	digest := sha1.New()
	if _, err := digest.Write(secret.Certificate.Raw); err != nil {
		return "", err
	}
	thumbprint := base64.URLEncoding.EncodeToString(digest.Sum(nil))

	// The jti (JWT ID) claim provides a unique identifier for the JWT.
	nonce := make([]byte, 20)
	if _, err := rand.Read(nonce); err != nil {
		return "", err
	}

	signer := jwt.New(jwt.SigningMethodRS256)
	signer.Header["x5t"] = thumbprint
	signer.Header["x5c"] = []string{base64.StdEncoding.EncodeToString(secret.Certificate.Raw)}
	signer.Claims = jwt.MapClaims{
		"aud": spt.inner.OauthConfig.TokenEndpoint.String(),
		"iss": spt.inner.ClientID,
		"sub": spt.inner.ClientID,
		"jti": base64.URLEncoding.EncodeToString(nonce),
		"nbf": time.Now().Unix(),
		"exp": time.Now().Add(24 * time.Hour).Unix(),
	}

	return signer.SignedString(secret.PrivateKey)
}

// SetAuthenticationValues is a method of the interface ServicePrincipalSecret.
// It will populate the form submitted during oAuth Token Acquisition using a JWT signed with a certificate.
func (secret *ServicePrincipalCertificateSecret) SetAuthenticationValues(spt *ServicePrincipalToken, v *url.Values) error {
	assertion, err := secret.SignJwt(spt)
	if err != nil {
		return err
	}
	v.Set("client_assertion", assertion)
	v.Set("client_assertion_type", "urn:ietf:params:oauth:client-assertion-type:jwt-bearer")
	return nil
}

// MarshalJSON implements the json.Marshaler interface.
// Certificate secrets are never serialized, so this always fails.
func (secret ServicePrincipalCertificateSecret) MarshalJSON() ([]byte, error) {
	return nil, errors.New("marshalling ServicePrincipalCertificateSecret is not supported")
}

// ServicePrincipalMSISecret implements ServicePrincipalSecret for machines running the MSI Extension.
// It carries no credential data of its own.
type ServicePrincipalMSISecret struct{}

// SetAuthenticationValues is a method of the interface ServicePrincipalSecret.
// It is intentionally a no-op: the form values are left exactly as supplied.
func (*ServicePrincipalMSISecret) SetAuthenticationValues(spt *ServicePrincipalToken, v *url.Values) error {
	return nil
}

// MarshalJSON implements the json.Marshaler interface.
// MSI secrets are never serialized, so this always fails.
func (ServicePrincipalMSISecret) MarshalJSON() ([]byte, error) {
	return nil, errors.New("marshalling ServicePrincipalMSISecret is not supported")
}

// ServicePrincipalUsernamePasswordSecret implements ServicePrincipalSecret for username and password auth.
type ServicePrincipalUsernamePasswordSecret struct {
	Username string `json:"username"`
	Password string `json:"password"`
}

// SetAuthenticationValues is a method of the interface ServicePrincipalSecret.
// It sets the "username" and "password" form fields, overwriting any existing values.
func (secret *ServicePrincipalUsernamePasswordSecret) SetAuthenticationValues(spt *ServicePrincipalToken, v *url.Values) error {
	fields := [...][2]string{
		{"username", secret.Username},
		{"password", secret.Password},
	}
	for _, field := range fields {
		v.Set(field[0], field[1])
	}
	return nil
}

// MarshalJSON implements the json.Marshaler interface.
// The output contains a "type" discriminator and the credentials in clear text.
func (secret ServicePrincipalUsernamePasswordSecret) MarshalJSON() ([]byte, error) {
	payload := struct {
		Type     string `json:"type"`
		Username string `json:"username"`
		Password string `json:"password"`
	}{"ServicePrincipalUsernamePasswordSecret", secret.Username, secret.Password}
	return json.Marshal(payload)
}

// ServicePrincipalAuthorizationCodeSecret implements ServicePrincipalSecret for authorization code auth.
type ServicePrincipalAuthorizationCodeSecret struct {
	ClientSecret      string `json:"value"`
	AuthorizationCode string `json:"authCode"`
	RedirectURI       string `json:"redirect"`
}

// SetAuthenticationValues is a method of the interface ServicePrincipalSecret.
// It sets the "code", "client_secret" and "redirect_uri" form fields.
func (secret *ServicePrincipalAuthorizationCodeSecret) SetAuthenticationValues(spt *ServicePrincipalToken, v *url.Values) error {
	fields := [...][2]string{
		{"code", secret.AuthorizationCode},
		{"client_secret", secret.ClientSecret},
		{"redirect_uri", secret.RedirectURI},
	}
	for _, field := range fields {
		v.Set(field[0], field[1])
	}
	return nil
}

// MarshalJSON implements the json.Marshaler interface.
// The output contains a "type" discriminator followed by the secret, code and redirect URI.
func (secret ServicePrincipalAuthorizationCodeSecret) MarshalJSON() ([]byte, error) {
	payload := struct {
		Type     string `json:"type"`
		Value    string `json:"value"`
		AuthCode string `json:"authCode"`
		Redirect string `json:"redirect"`
	}{"ServicePrincipalAuthorizationCodeSecret", secret.ClientSecret, secret.AuthorizationCode, secret.RedirectURI}
	return json.Marshal(payload)
}

// ServicePrincipalToken encapsulates a Token created for a Service Principal.
// Only the inner state is serialized by MarshalJSON; the lock, sender, custom refresh
// function and callbacks are runtime-only.
type ServicePrincipalToken struct {
	inner             servicePrincipalToken
	refreshLock       *sync.RWMutex
	sender            Sender
	customRefreshFunc TokenRefresh
	refreshCallbacks  []TokenRefreshCallback
	// MaxMSIRefreshAttempts is the maximum number of attempts to refresh an MSI token.
	// Setting this to a value less than 1 will use the default value.
	MaxMSIRefreshAttempts int
}

// MarshalTokenJSON returns the JSON encoding of the current inner Token alone.
func (spt ServicePrincipalToken) MarshalTokenJSON() ([]byte, error) {
	current := spt.inner.Token
	return json.Marshal(current)
}

// SetRefreshCallbacks replaces any existing refresh callbacks with the specified callbacks.
func (spt *ServicePrincipalToken) SetRefreshCallbacks(callbacks []TokenRefreshCallback) {
	spt.refreshCallbacks = callbacks
}

// SetCustomRefreshFunc sets a custom refresh function used to refresh the token.
// When set, it is used instead of the built-in HTTP token request.
func (spt *ServicePrincipalToken) SetCustomRefreshFunc(customRefreshFunc TokenRefresh) {
	spt.customRefreshFunc = customRefreshFunc
}

// MarshalJSON implements the json.Marshaler interface by encoding the inner state.
func (spt ServicePrincipalToken) MarshalJSON() ([]byte, error) {
	state := spt.inner
	return json.Marshal(state)
}

// UnmarshalJSON implements the json.Unmarshaler interface.
// The "type" field of the "secret" object selects the concrete ServicePrincipalSecret
// to decode into. Certificate and MSI secrets are rejected. A missing, null or non-object
// "secret" field is reported as an error instead of panicking.
func (spt *ServicePrincipalToken) UnmarshalJSON(data []byte) error {
	// need to determine the token type
	raw := map[string]interface{}{}
	err := json.Unmarshal(data, &raw)
	if err != nil {
		return err
	}
	// Use the comma-ok form: an unchecked assertion panics on malformed input.
	secret, ok := raw["secret"].(map[string]interface{})
	if !ok {
		return errors.New("unmarshalling ServicePrincipalToken: missing or invalid 'secret' field")
	}
	switch secret["type"] {
	case "ServicePrincipalNoSecret":
		spt.inner.Secret = &ServicePrincipalNoSecret{}
	case "ServicePrincipalTokenSecret":
		spt.inner.Secret = &ServicePrincipalTokenSecret{}
	case "ServicePrincipalCertificateSecret":
		return errors.New("unmarshalling ServicePrincipalCertificateSecret is not supported")
	case "ServicePrincipalMSISecret":
		return errors.New("unmarshalling ServicePrincipalMSISecret is not supported")
	case "ServicePrincipalUsernamePasswordSecret":
		spt.inner.Secret = &ServicePrincipalUsernamePasswordSecret{}
	case "ServicePrincipalAuthorizationCodeSecret":
		spt.inner.Secret = &ServicePrincipalAuthorizationCodeSecret{}
	default:
		// %v (not %s) so a missing or non-string type prints readably.
		return fmt.Errorf("unrecognized token type '%v'", secret["type"])
	}
	err = json.Unmarshal(data, &spt.inner)
	if err != nil {
		return err
	}
	// Don't override the refreshLock or the sender if those have been already set.
	if spt.refreshLock == nil {
		spt.refreshLock = &sync.RWMutex{}
	}
	if spt.sender == nil {
		spt.sender = sender()
	}
	return nil
}

// internal type used for marshalling/unmarshalling
//
// servicePrincipalToken holds the persisted state of a ServicePrincipalToken. The field
// order and JSON tags determine the serialized layout, so they must stay stable:
//   - Token: the most recently stored token.
//   - Secret: supplies credential form values. Serializable implementations include a
//     "type" discriminator in their JSON.
//   - OauthConfig: its TokenEndpoint is the URL the refresh request is sent to.
//   - ClientID: sent as client_id for non-MSI requests.
//   - Resource: the resource used by Refresh/EnsureFresh.
//   - AutoRefresh: EnsureFresh only refreshes when this is true.
//   - RefreshWithin: EnsureFresh refreshes once the token will expire within this window.
type servicePrincipalToken struct {
	Token         Token                  `json:"token"`
	Secret        ServicePrincipalSecret `json:"secret"`
	OauthConfig   OAuthConfig            `json:"oauth"`
	ClientID      string                 `json:"clientID"`
	Resource      string                 `json:"resource"`
	AutoRefresh   bool                   `json:"autoRefresh"`
	RefreshWithin time.Duration          `json:"refreshWithin"`
}

// validateOAuthConfig rejects a zero-initialized OAuthConfig and accepts anything else.
func validateOAuthConfig(oac OAuthConfig) error {
	if !oac.IsZero() {
		return nil
	}
	return fmt.Errorf("parameter 'oauthConfig' cannot be zero-initialized")
}

// NewServicePrincipalTokenWithSecret creates a ServicePrincipalToken using the supplied ServicePrincipalSecret implementation.
// The token starts as a placeholder (numeric fields "0") with auto-refresh enabled and the
// default refresh window; callbacks are invoked after each successful refresh.
func NewServicePrincipalTokenWithSecret(oauthConfig OAuthConfig, id string, resource string, secret ServicePrincipalSecret, callbacks ...TokenRefreshCallback) (*ServicePrincipalToken, error) {
	if err := validateOAuthConfig(oauthConfig); err != nil {
		return nil, err
	}
	for _, param := range [][2]string{{id, "id"}, {resource, "resource"}} {
		if err := validateStringParam(param[0], param[1]); err != nil {
			return nil, err
		}
	}
	if secret == nil {
		return nil, fmt.Errorf("parameter 'secret' cannot be nil")
	}

	state := servicePrincipalToken{
		Token:         newToken(),
		OauthConfig:   oauthConfig,
		Secret:        secret,
		ClientID:      id,
		Resource:      resource,
		AutoRefresh:   true,
		RefreshWithin: defaultRefresh,
	}
	return &ServicePrincipalToken{
		inner:            state,
		refreshLock:      new(sync.RWMutex),
		sender:           sender(),
		refreshCallbacks: callbacks,
	}, nil
}

// NewServicePrincipalTokenFromManualToken creates a ServicePrincipalToken using the supplied token.
// The result has no secret material (ServicePrincipalNoSecret), so it can only be refreshed
// through a refresh token or a custom refresh function.
func NewServicePrincipalTokenFromManualToken(oauthConfig OAuthConfig, clientID string, resource string, token Token, callbacks ...TokenRefreshCallback) (*ServicePrincipalToken, error) {
	if err := validateOAuthConfig(oauthConfig); err != nil {
		return nil, err
	}
	for _, param := range [][2]string{{clientID, "clientID"}, {resource, "resource"}} {
		if err := validateStringParam(param[0], param[1]); err != nil {
			return nil, err
		}
	}
	if token.IsZero() {
		return nil, fmt.Errorf("parameter 'token' cannot be zero-initialized")
	}

	spt, err := NewServicePrincipalTokenWithSecret(oauthConfig, clientID, resource, &ServicePrincipalNoSecret{}, callbacks...)
	if err != nil {
		return nil, err
	}
	// Replace the constructor's placeholder token with the caller's token.
	spt.inner.Token = token
	return spt, nil
}

// NewServicePrincipalTokenFromManualTokenSecret creates a ServicePrincipalToken using the supplied token and secret.
// The token is used as-is until it needs refreshing, at which point the secret supplies credentials.
func NewServicePrincipalTokenFromManualTokenSecret(oauthConfig OAuthConfig, clientID string, resource string, token Token, secret ServicePrincipalSecret, callbacks ...TokenRefreshCallback) (*ServicePrincipalToken, error) {
	if err := validateOAuthConfig(oauthConfig); err != nil {
		return nil, err
	}
	for _, param := range [][2]string{{clientID, "clientID"}, {resource, "resource"}} {
		if err := validateStringParam(param[0], param[1]); err != nil {
			return nil, err
		}
	}
	switch {
	case secret == nil:
		return nil, fmt.Errorf("parameter 'secret' cannot be nil")
	case token.IsZero():
		return nil, fmt.Errorf("parameter 'token' cannot be zero-initialized")
	}

	spt, err := NewServicePrincipalTokenWithSecret(oauthConfig, clientID, resource, secret, callbacks...)
	if err != nil {
		return nil, err
	}
	// Replace the constructor's placeholder token with the caller's token.
	spt.inner.Token = token
	return spt, nil
}

// NewServicePrincipalToken creates a ServicePrincipalToken from the supplied Service Principal
// credentials scoped to the named resource. The client secret is sent as "client_secret"
// when a new token is requested.
func NewServicePrincipalToken(oauthConfig OAuthConfig, clientID string, secret string, resource string, callbacks ...TokenRefreshCallback) (*ServicePrincipalToken, error) {
	if err := validateOAuthConfig(oauthConfig); err != nil {
		return nil, err
	}
	required := [][2]string{
		{clientID, "clientID"},
		{secret, "secret"},
		{resource, "resource"},
	}
	for _, param := range required {
		if err := validateStringParam(param[0], param[1]); err != nil {
			return nil, err
		}
	}
	clientSecret := &ServicePrincipalTokenSecret{ClientSecret: secret}
	return NewServicePrincipalTokenWithSecret(oauthConfig, clientID, resource, clientSecret, callbacks...)
}

// NewServicePrincipalTokenFromCertificate creates a ServicePrincipalToken that authenticates
// with a client assertion (a JWT) signed by the supplied certificate's RSA private key.
// Both certificate and privateKey must be non-nil. Any decoding of PKCS#12 or PEM material
// must happen before calling this function.
func NewServicePrincipalTokenFromCertificate(oauthConfig OAuthConfig, clientID string, certificate *x509.Certificate, privateKey *rsa.PrivateKey, resource string, callbacks ...TokenRefreshCallback) (*ServicePrincipalToken, error) {
	if err := validateOAuthConfig(oauthConfig); err != nil {
		return nil, err
	}
	if err := validateStringParam(clientID, "clientID"); err != nil {
		return nil, err
	}
	if err := validateStringParam(resource, "resource"); err != nil {
		return nil, err
	}
	if certificate == nil {
		return nil, fmt.Errorf("parameter 'certificate' cannot be nil")
	}
	if privateKey == nil {
		return nil, fmt.Errorf("parameter 'privateKey' cannot be nil")
	}
	return NewServicePrincipalTokenWithSecret(
		oauthConfig,
		clientID,
		resource,
		&ServicePrincipalCertificateSecret{
			PrivateKey:  privateKey,
			Certificate: certificate,
		},
		callbacks...,
	)
}

// NewServicePrincipalTokenFromUsernamePassword creates a ServicePrincipalToken from the username and password.
// Token requests use the "password" grant type with the supplied credentials.
func NewServicePrincipalTokenFromUsernamePassword(oauthConfig OAuthConfig, clientID string, username string, password string, resource string, callbacks ...TokenRefreshCallback) (*ServicePrincipalToken, error) {
	if err := validateOAuthConfig(oauthConfig); err != nil {
		return nil, err
	}
	required := [][2]string{
		{clientID, "clientID"},
		{username, "username"},
		{password, "password"},
		{resource, "resource"},
	}
	for _, param := range required {
		if err := validateStringParam(param[0], param[1]); err != nil {
			return nil, err
		}
	}
	credentials := &ServicePrincipalUsernamePasswordSecret{Username: username, Password: password}
	return NewServicePrincipalTokenWithSecret(oauthConfig, clientID, resource, credentials, callbacks...)
}

// NewServicePrincipalTokenFromAuthorizationCode creates a ServicePrincipalToken from the
// supplied authorization code, client secret and redirect URI. The first token request
// uses the "authorization_code" grant type. Later refreshes that use a refresh token
// still send the client secret, as web apps are required to.
func NewServicePrincipalTokenFromAuthorizationCode(oauthConfig OAuthConfig, clientID string, clientSecret string, authorizationCode string, redirectURI string, resource string, callbacks ...TokenRefreshCallback) (*ServicePrincipalToken, error) {
	if err := validateOAuthConfig(oauthConfig); err != nil {
		return nil, err
	}
	if err := validateStringParam(clientID, "clientID"); err != nil {
		return nil, err
	}
	if err := validateStringParam(clientSecret, "clientSecret"); err != nil {
		return nil, err
	}
	if err := validateStringParam(authorizationCode, "authorizationCode"); err != nil {
		return nil, err
	}
	if err := validateStringParam(redirectURI, "redirectURI"); err != nil {
		return nil, err
	}
	if err := validateStringParam(resource, "resource"); err != nil {
		return nil, err
	}

	return NewServicePrincipalTokenWithSecret(
		oauthConfig,
		clientID,
		resource,
		&ServicePrincipalAuthorizationCodeSecret{
			ClientSecret:      clientSecret,
			AuthorizationCode: authorizationCode,
			RedirectURI:       redirectURI,
		},
		callbacks...,
	)
}

// GetMSIVMEndpoint gets the MSI endpoint on Virtual Machines.
// The endpoint is a fixed well-known address, so the returned error is always nil.
func GetMSIVMEndpoint() (string, error) {
	endpoint := msiEndpoint
	return endpoint, nil
}

// isAppService reports whether both the App Service MSI endpoint and secret
// environment variables are set (even if empty).
// NOTE: this only indicates if the ASE environment credentials have been set
// which does not necessarily mean that the caller is authenticating via ASE!
func isAppService() bool {
	for _, name := range []string{asMSIEndpointEnv, asMSISecretEnv} {
		if _, found := os.LookupEnv(name); !found {
			return false
		}
	}
	return true
}

// GetMSIAppServiceEndpoint gets the MSI endpoint for App Service and Functions from the
// environment, returning an error when the variable is not set.
func GetMSIAppServiceEndpoint() (string, error) {
	endpoint, found := os.LookupEnv(asMSIEndpointEnv)
	if !found {
		return "", errors.New("MSI endpoint not found")
	}
	return endpoint, nil
}

// GetMSIEndpoint gets the appropriate MSI endpoint depending on the runtime environment:
// the App Service endpoint when its environment variables are present, otherwise the VM endpoint.
func GetMSIEndpoint() (string, error) {
	if !isAppService() {
		return GetMSIVMEndpoint()
	}
	return GetMSIAppServiceEndpoint()
}

// NewServicePrincipalTokenFromMSI creates a ServicePrincipalToken via the MSI VM Extension.
// It will use the system assigned identity when creating the token.
func NewServicePrincipalTokenFromMSI(msiEndpoint, resource string, callbacks ...TokenRefreshCallback) (*ServicePrincipalToken, error) {
	// Neither a user assigned client ID nor an identity resource ID is supplied.
	var noClientID, noResourceID *string
	return newServicePrincipalTokenFromMSI(msiEndpoint, resource, noClientID, noResourceID, callbacks...)
}

// NewServicePrincipalTokenFromMSIWithUserAssignedID creates a ServicePrincipalToken via the MSI VM Extension.
// It will use the clientID of specified user assigned identity when creating the token.
func NewServicePrincipalTokenFromMSIWithUserAssignedID(msiEndpoint, resource string, userAssignedID string, callbacks ...TokenRefreshCallback) (*ServicePrincipalToken, error) {
	clientID := userAssignedID
	return newServicePrincipalTokenFromMSI(msiEndpoint, resource, &clientID, nil, callbacks...)
}

// NewServicePrincipalTokenFromMSIWithIdentityResourceID creates a ServicePrincipalToken via the MSI VM Extension.
// It will use the azure resource id of user assigned identity when creating the token.
func NewServicePrincipalTokenFromMSIWithIdentityResourceID(msiEndpoint, resource string, identityResourceID string, callbacks ...TokenRefreshCallback) (*ServicePrincipalToken, error) {
	resourceID := identityResourceID
	return newServicePrincipalTokenFromMSI(msiEndpoint, resource, nil, &resourceID, callbacks...)
}

// newServicePrincipalTokenFromMSI builds an MSI-backed ServicePrincipalToken. The MSI endpoint
// becomes the token endpoint, and all request parameters (resource, api-version and the
// optional client_id / mi_res_id) are encoded into its query string.
func newServicePrincipalTokenFromMSI(msiEndpoint, resource string, userAssignedID *string, identityResourceID *string, callbacks ...TokenRefreshCallback) (*ServicePrincipalToken, error) {
	if err := validateStringParam(msiEndpoint, "msiEndpoint"); err != nil {
		return nil, err
	}
	if err := validateStringParam(resource, "resource"); err != nil {
		return nil, err
	}
	optional := []struct {
		value *string
		name  string
	}{
		{userAssignedID, "userAssignedID"},
		{identityResourceID, "identityResourceID"},
	}
	for _, param := range optional {
		if param.value == nil {
			continue
		}
		if err := validateStringParam(*param.value, param.name); err != nil {
			return nil, err
		}
	}

	// We set the oauth config token endpoint to be MSI's endpoint
	tokenURL, err := url.Parse(msiEndpoint)
	if err != nil {
		return nil, err
	}

	apiVersion := msiAPIVersion
	if isAppService() {
		// App Service MSI currently only supports token API version 2017-09-01
		apiVersion = appServiceAPIVersion
	}
	query := url.Values{}
	query.Set("resource", resource)
	query.Set("api-version", apiVersion)
	clientID := ""
	if userAssignedID != nil {
		clientID = *userAssignedID
		query.Set("client_id", clientID)
	}
	if identityResourceID != nil {
		query.Set("mi_res_id", *identityResourceID)
	}
	tokenURL.RawQuery = query.Encode()

	return &ServicePrincipalToken{
		inner: servicePrincipalToken{
			Token:         newToken(),
			OauthConfig:   OAuthConfig{TokenEndpoint: *tokenURL},
			Secret:        &ServicePrincipalMSISecret{},
			ClientID:      clientID,
			Resource:      resource,
			AutoRefresh:   true,
			RefreshWithin: defaultRefresh,
		},
		refreshLock:           &sync.RWMutex{},
		sender:                sender(),
		refreshCallbacks:      callbacks,
		MaxMSIRefreshAttempts: defaultMaxMSIRefreshAttempts,
	}, nil
}

// tokenRefreshError is the package's concrete TokenRefreshError: a fixed message paired
// with the HTTP response (possibly nil) that accompanied the failure.
type tokenRefreshError struct {
	message string
	resp    *http.Response
}

// Error implements the error interface which is part of the TokenRefreshError interface.
func (e tokenRefreshError) Error() string { return e.message }

// Response implements the TokenRefreshError interface, it returns the raw HTTP response from the refresh operation.
func (e tokenRefreshError) Response() *http.Response { return e.resp }

// newTokenRefreshError wraps message and resp in a tokenRefreshError value.
func newTokenRefreshError(message string, resp *http.Response) TokenRefreshError {
	var e tokenRefreshError
	e.message, e.resp = message, resp
	return e
}

// EnsureFresh will refresh the token if it will expire within the refresh window (as set by
// RefreshWithin) and autoRefresh flag is on. It uses context.Background().
// This method is safe for concurrent use.
func (spt *ServicePrincipalToken) EnsureFresh() error {
	ctx := context.Background()
	return spt.EnsureFreshWithContext(ctx)
}

// EnsureFreshWithContext will refresh the token if it will expire within the refresh window (as set by
// RefreshWithin) and autoRefresh flag is on.  This method is safe for concurrent use.
func (spt *ServicePrincipalToken) EnsureFreshWithContext(ctx context.Context) error {
	// The initial check reads the token through Token(), which is expected to take the read lock.
	if !spt.inner.AutoRefresh || !spt.Token().WillExpireIn(spt.inner.RefreshWithin) {
		return nil
	}
	spt.refreshLock.Lock()
	defer spt.refreshLock.Unlock()
	// Another goroutine may have refreshed while this one waited for the write lock.
	if !spt.inner.Token.WillExpireIn(spt.inner.RefreshWithin) {
		return nil
	}
	return spt.refreshInternal(ctx, spt.inner.Resource)
}

// InvokeRefreshCallbacks calls, in order, every TokenRefreshCallback registered on the SPT
// (at construction or via SetRefreshCallbacks), passing each one the supplied token.
// It stops at the first callback that fails and returns that failure wrapped in an adal error.
func (spt *ServicePrincipalToken) InvokeRefreshCallbacks(token Token) error {
	for _, callback := range spt.refreshCallbacks {
		// Pass the caller-supplied token. Previously the parameter was ignored and
		// spt.inner.Token was passed instead.
		if err := callback(token); err != nil {
			return fmt.Errorf("adal: TokenRefreshCallback handler failed. Error = '%v'", err)
		}
	}
	return nil
}

// Refresh obtains a fresh token for the Service Principal using context.Background().
// This method is safe for concurrent use.
func (spt *ServicePrincipalToken) Refresh() error {
	ctx := context.Background()
	return spt.RefreshWithContext(ctx)
}

// RefreshWithContext obtains a fresh token for the Service Principal's configured resource,
// regardless of the current token's expiry. The write lock is held for the whole refresh.
// This method is safe for concurrent use.
func (spt *ServicePrincipalToken) RefreshWithContext(ctx context.Context) error {
	spt.refreshLock.Lock()
	defer spt.refreshLock.Unlock()
	// The configured resource is read only after the lock is held.
	target := spt.inner.Resource
	return spt.refreshInternal(ctx, target)
}

// RefreshExchange refreshes the token, but for a different resource, using context.Background().
// This method is safe for concurrent use.
func (spt *ServicePrincipalToken) RefreshExchange(resource string) error {
	ctx := context.Background()
	return spt.RefreshExchangeWithContext(ctx, resource)
}

// RefreshExchangeWithContext refreshes the token, but for a different resource.
// The write lock is held for the whole refresh. This method is safe for concurrent use.
func (spt *ServicePrincipalToken) RefreshExchangeWithContext(ctx context.Context, resource string) error {
	spt.refreshLock.Lock()
	defer spt.refreshLock.Unlock()
	target := resource
	return spt.refreshInternal(ctx, target)
}

// getGrantType maps the configured secret's concrete type to the OAuth grant type:
// username/password and authorization-code secrets have dedicated grants, and every other
// secret (including a nil one) uses client credentials.
func (spt *ServicePrincipalToken) getGrantType() string {
	if _, isUserPass := spt.inner.Secret.(*ServicePrincipalUsernamePasswordSecret); isUserPass {
		return OAuthGrantTypeUserPass
	}
	if _, isAuthCode := spt.inner.Secret.(*ServicePrincipalAuthorizationCodeSecret); isAuthCode {
		return OAuthGrantTypeAuthorizationCode
	}
	return OAuthGrantTypeClientCredentials
}

// isIMDS reports whether u targets either the VM MSI endpoint or the App Service MSI endpoint.
func isIMDS(u url.URL) bool {
	return isMSIEndpoint(u) || isASEEndpoint(u)
}

// isMSIEndpoint reports whether endpoint has the same host and path as the well-known VM
// MSI endpoint. Scheme and query string are ignored.
func isMSIEndpoint(endpoint url.URL) bool {
	known, err := url.Parse(msiEndpoint)
	if err != nil {
		return false
	}
	sameHost := endpoint.Host == known.Host
	samePath := endpoint.Path == known.Path
	return sameHost && samePath
}

// isASEEndpoint reports whether endpoint has the same host and path as the App Service MSI
// endpoint taken from the environment. It returns false when that variable is unset or unparsable.
func isASEEndpoint(endpoint url.URL) bool {
	raw, lookupErr := GetMSIAppServiceEndpoint()
	if lookupErr != nil {
		// app service environment isn't enabled
		return false
	}
	known, parseErr := url.Parse(raw)
	if parseErr != nil {
		return false
	}
	sameHost := endpoint.Host == known.Host
	samePath := endpoint.Path == known.Path
	return sameHost && samePath
}

// refreshInternal performs one token refresh for resource. On success it stores the new
// token in spt.inner.Token and runs the refresh callbacks. Every caller visible in this file
// holds spt.refreshLock for writing while calling it.
//
// If a custom refresh function is set, it is used exclusively. Otherwise a request is built
// for the configured token endpoint. MSI endpoints (IMDS / App Service) carry their parameters
// in the URL and are retried via retryForIMDS. Other endpoints get a form-encoded body with
// either a refresh_token grant or the secret's own grant. A non-200 response yields a
// TokenRefreshError, which is not retried. Transport, read and decode failures yield plain
// errors so that retry logic can apply.
func (spt *ServicePrincipalToken) refreshInternal(ctx context.Context, resource string) error {
	if spt.customRefreshFunc != nil {
		token, err := spt.customRefreshFunc(ctx, resource)
		if err != nil {
			return err
		}
		// A custom function returning (nil, nil) would otherwise cause a nil-pointer panic.
		if token == nil {
			return errors.New("adal: custom refresh function returned a nil token")
		}
		spt.inner.Token = *token
		return spt.InvokeRefreshCallbacks(spt.inner.Token)
	}

	req, err := http.NewRequest(http.MethodPost, spt.inner.OauthConfig.TokenEndpoint.String(), nil)
	if err != nil {
		return fmt.Errorf("adal: Failed to build the refresh request. Error = '%v'", err)
	}
	req.Header.Add("User-Agent", UserAgent())
	// Add header when runtime is on App Service or Functions
	if isASEEndpoint(spt.inner.OauthConfig.TokenEndpoint) {
		asMSISecret, _ := os.LookupEnv(asMSISecretEnv)
		req.Header.Add("Secret", asMSISecret)
	}
	req = req.WithContext(ctx)
	if !isIMDS(spt.inner.OauthConfig.TokenEndpoint) {
		v := url.Values{}
		v.Set("client_id", spt.inner.ClientID)
		v.Set("resource", resource)

		if spt.inner.Token.RefreshToken != "" {
			v.Set("grant_type", OAuthGrantTypeRefreshToken)
			v.Set("refresh_token", spt.inner.Token.RefreshToken)
			// web apps must specify client_secret when refreshing tokens
			// see https://docs.microsoft.com/en-us/azure/active-directory/develop/active-directory-protocols-oauth-code#refreshing-the-access-tokens
			if spt.getGrantType() == OAuthGrantTypeAuthorizationCode {
				err := spt.inner.Secret.SetAuthenticationValues(spt, &v)
				if err != nil {
					return err
				}
			}
		} else {
			v.Set("grant_type", spt.getGrantType())
			err := spt.inner.Secret.SetAuthenticationValues(spt, &v)
			if err != nil {
				return err
			}
		}

		s := v.Encode()
		body := ioutil.NopCloser(strings.NewReader(s))
		req.ContentLength = int64(len(s))
		req.Header.Set(contentType, mimeTypeFormPost)
		req.Body = body
	}

	if _, ok := spt.inner.Secret.(*ServicePrincipalMSISecret); ok {
		req.Method = http.MethodGet
		req.Header.Set(metadataHeader, "true")
	}

	var resp *http.Response
	if isMSIEndpoint(spt.inner.OauthConfig.TokenEndpoint) && !MSIAvailable(ctx, spt.sender) {
		// return a TokenRefreshError here so that we don't keep retrying
		return newTokenRefreshError("the MSI endpoint is not available", nil)
	}
	if isIMDS(spt.inner.OauthConfig.TokenEndpoint) {
		resp, err = retryForIMDS(spt.sender, req, spt.MaxMSIRefreshAttempts)
	} else {
		resp, err = spt.sender.Do(req)
	}
	if err != nil {
		// don't return a TokenRefreshError here; this will allow retry logic to apply
		return fmt.Errorf("adal: Failed to execute the refresh request. Error = '%v'", err)
	}

	defer resp.Body.Close()
	rb, err := ioutil.ReadAll(resp.Body)

	if resp.StatusCode != http.StatusOK {
		if err != nil {
			return newTokenRefreshError(fmt.Sprintf("adal: Refresh request failed. Status Code = '%d'. Failed reading response body: %v", resp.StatusCode, err), resp)
		}
		return newTokenRefreshError(fmt.Sprintf("adal: Refresh request failed. Status Code = '%d'. Response body: %s", resp.StatusCode, string(rb)), resp)
	}

	// for the following error cases don't return a TokenRefreshError.  the operation succeeded
	// but some transient failure happened during deserialization.  by returning a generic error
	// the retry logic will kick in (we don't retry on TokenRefreshError).

	if err != nil {
		return fmt.Errorf("adal: Failed to read a new service principal token during refresh. Error = '%v'", err)
	}
	// TrimSpace (not just spaces) so a body of only newlines/tabs is also treated as empty.
	if len(strings.TrimSpace(string(rb))) == 0 {
		return fmt.Errorf("adal: Empty service principal token received during refresh")
	}
	var token Token
	err = json.Unmarshal(rb, &token)
	if err != nil {
		return fmt.Errorf("adal: Failed to unmarshal the service principal token during refresh. Error = '%v' JSON = '%s'", err, string(rb))
	}

	spt.inner.Token = token

	return spt.InvokeRefreshCallbacks(token)
}

// retryForIMDS sends req via sender, retrying with capped exponential backoff when the
// request fails or the IMDS endpoint answers with a retriable status code.
//
// Retriable status codes are the generic transient set (408, 429, 500, 502, 503, 504)
// plus the IMDS-specific 404, 410 and every remaining 5xx code. At most maxAttempts
// requests are made; a value below one falls back to defaultMaxMSIRefreshAttempts.
//
// Returns:
//   - on success or a non-retriable status: that response, with its body open (the
//     caller must close it) and a nil error;
//   - when attempts are exhausted: the last response and/or error, returned immediately
//     without a final backoff sleep;
//   - when req's context is done while waiting between attempts: (nil, ctx.Err()), after
//     the pending response body has been drained and closed.
func retryForIMDS(sender Sender, req *http.Request, maxAttempts int) (resp *http.Response, err error) {
	// copied from client.go due to circular dependency
	retries := []int{
		http.StatusRequestTimeout,      // 408
		http.StatusTooManyRequests,     // 429
		http.StatusInternalServerError, // 500
		http.StatusBadGateway,          // 502
		http.StatusServiceUnavailable,  // 503
		http.StatusGatewayTimeout,      // 504
	}
	// extra retry status codes specific to IMDS
	retries = append(retries,
		http.StatusNotFound,
		http.StatusGone,
		// all remaining 5xx
		http.StatusNotImplemented,
		http.StatusHTTPVersionNotSupported,
		http.StatusVariantAlsoNegotiates,
		http.StatusInsufficientStorage,
		http.StatusLoopDetected,
		http.StatusNotExtended,
		http.StatusNetworkAuthenticationRequired)

	// see https://docs.microsoft.com/en-us/azure/active-directory/managed-service-identity/how-to-use-vm-token#retry-guidance

	const maxDelay time.Duration = 60 * time.Second

	// discard drains (best effort, so the connection can be reused) and closes a
	// response body that will not be handed back to the caller.
	discard := func(r *http.Response) {
		if r != nil && r.Body != nil {
			_, _ = io.Copy(ioutil.Discard, r.Body)
			r.Body.Close()
		}
	}

	attempt := 0
	delay := time.Duration(0)

	// maxAttempts is user-specified, ensure that its value is greater than zero else no request will be made
	if maxAttempts < 1 {
		maxAttempts = defaultMaxMSIRefreshAttempts
	}

	for attempt < maxAttempts {
		// release the previous (retriable) response before issuing a new request
		discard(resp)
		resp, err = sender.Do(req)
		// we want to retry if err is not nil or the status code is in the list of retry codes
		if err == nil && !responseHasStatusCode(resp, retries...) {
			return resp, nil
		}

		// must increment attempt before calculating delay.
		attempt++
		if attempt >= maxAttempts {
			// no further request will be made; don't make the caller wait for nothing
			break
		}

		// perform exponential backoff with a cap.
		// the base value of 2 is the "delta backoff" as specified in the guidance doc
		delay += (time.Duration(math.Pow(2, float64(attempt))) * time.Second)
		if delay > maxDelay {
			delay = maxDelay
		}

		select {
		case <-time.After(delay):
			// intentionally left blank
		case <-req.Context().Done():
			// the caller ignores resp when err is non-nil, so release it here
			discard(resp)
			return nil, req.Context().Err()
		}
	}
	return resp, err
}

// responseHasStatusCode reports whether resp is non-nil and its StatusCode equals one
// of the supplied codes. A nil response or an empty codes list yields false.
func responseHasStatusCode(resp *http.Response, codes ...int) bool {
	if resp == nil {
		return false
	}
	for _, code := range codes {
		if resp.StatusCode == code {
			return true
		}
	}
	return false
}

// SetAutoRefresh turns automatic refreshing of stale tokens on (true) or off (false)
// by storing the flag on the wrapped token state.
func (spt *ServicePrincipalToken) SetAutoRefresh(autoRefresh bool) { spt.inner.AutoRefresh = autoRefresh }

// SetRefreshWithin sets the refresh window: if the token will expire within d,
// EnsureFresh will refresh it.
func (spt *ServicePrincipalToken) SetRefreshWithin(d time.Duration) {
	spt.inner.RefreshWithin = d
}

// SetSender replaces the Sender used to issue token requests for this Service
// Principal. By default an undecorated http.Client is used.
func (spt *ServicePrincipalToken) SetSender(s Sender) {
	spt.sender = s
}

// OAuthToken implements the OAuthTokenProvider interface. It reads the current
// access token while holding the read side of refreshLock.
func (spt *ServicePrincipalToken) OAuthToken() string {
	spt.refreshLock.RLock()
	defer spt.refreshLock.RUnlock()
	current := spt.inner.Token
	return current.OAuthToken()
}

// Token returns a snapshot (value copy) of the current token, taken while holding
// the read side of refreshLock.
func (spt *ServicePrincipalToken) Token() Token {
	spt.refreshLock.RLock()
	defer spt.refreshLock.RUnlock()
	snapshot := spt.inner.Token
	return snapshot
}

// MultiTenantServicePrincipalToken contains tokens for multi-tenant authorization.
type MultiTenantServicePrincipalToken struct {
	// PrimaryToken is the token for the primary tenant.
	PrimaryToken *ServicePrincipalToken
	// AuxiliaryTokens holds one token per auxiliary tenant; order matches the
	// auxiliary tenants of the config it was built from (see NewMultiTenantServicePrincipalToken).
	AuxiliaryTokens []*ServicePrincipalToken
}

// PrimaryOAuthToken returns the current access token of the primary tenant's token.
func (mt *MultiTenantServicePrincipalToken) PrimaryOAuthToken() string {
	primary := mt.PrimaryToken
	return primary.OAuthToken()
}

// AuxiliaryOAuthTokens returns the current access token of every auxiliary tenant,
// in the same order as AuxiliaryTokens. The result is a non-nil (possibly empty) slice.
func (mt *MultiTenantServicePrincipalToken) AuxiliaryOAuthTokens() []string {
	out := make([]string, 0, len(mt.AuxiliaryTokens))
	for _, aux := range mt.AuxiliaryTokens {
		out = append(out, aux.OAuthToken())
	}
	return out
}

// EnsureFreshWithContext calls EnsureFreshWithContext on the primary token and then on
// each auxiliary token in order, so each is refreshed if it will expire within its
// refresh window (as set by RefreshWithin) and its autoRefresh flag is on.
//
// It stops at the first failure. The returned error identifies whether the primary
// or an auxiliary token failed and wraps the underlying error, so callers can
// inspect it with errors.Is / errors.As.
func (mt *MultiTenantServicePrincipalToken) EnsureFreshWithContext(ctx context.Context) error {
	if err := mt.PrimaryToken.EnsureFreshWithContext(ctx); err != nil {
		return fmt.Errorf("failed to refresh primary token: %w", err)
	}
	for _, aux := range mt.AuxiliaryTokens {
		if err := aux.EnsureFreshWithContext(ctx); err != nil {
			return fmt.Errorf("failed to refresh auxiliary token: %w", err)
		}
	}
	return nil
}

// RefreshWithContext obtains a fresh token for the primary tenant and then for each
// auxiliary tenant, in order.
//
// It stops at the first failure. The returned error identifies whether the primary
// or an auxiliary token failed and wraps the underlying error, so callers can
// inspect it with errors.Is / errors.As.
func (mt *MultiTenantServicePrincipalToken) RefreshWithContext(ctx context.Context) error {
	if err := mt.PrimaryToken.RefreshWithContext(ctx); err != nil {
		return fmt.Errorf("failed to refresh primary token: %w", err)
	}
	for _, aux := range mt.AuxiliaryTokens {
		if err := aux.RefreshWithContext(ctx); err != nil {
			return fmt.Errorf("failed to refresh auxiliary token: %w", err)
		}
	}
	return nil
}

// RefreshExchangeWithContext refreshes the primary token and then each auxiliary
// token, in order, but for the given resource.
//
// It stops at the first failure. The returned error identifies whether the primary
// or an auxiliary token failed and wraps the underlying error, so callers can
// inspect it with errors.Is / errors.As.
func (mt *MultiTenantServicePrincipalToken) RefreshExchangeWithContext(ctx context.Context, resource string) error {
	if err := mt.PrimaryToken.RefreshExchangeWithContext(ctx, resource); err != nil {
		return fmt.Errorf("failed to refresh primary token: %w", err)
	}
	for _, aux := range mt.AuxiliaryTokens {
		if err := aux.RefreshExchangeWithContext(ctx, resource); err != nil {
			return fmt.Errorf("failed to refresh auxiliary token: %w", err)
		}
	}
	return nil
}

// NewMultiTenantServicePrincipalToken creates a new MultiTenantServicePrincipalToken with the specified credentials and resource.
//
// clientID, secret and resource are validated with validateStringParam before any
// token is built. One ServicePrincipalToken is created for the primary tenant and one
// per auxiliary tenant, all sharing the same credentials and resource; AuxiliaryTokens[i]
// corresponds to multiTenantCfg.AuxiliaryTenants()[i].
//
// Errors from token construction wrap the underlying error (inspect with errors.Is /
// errors.As) and say whether the primary or an auxiliary tenant failed.
func NewMultiTenantServicePrincipalToken(multiTenantCfg MultiTenantOAuthConfig, clientID string, secret string, resource string) (*MultiTenantServicePrincipalToken, error) {
	if err := validateStringParam(clientID, "clientID"); err != nil {
		return nil, err
	}
	if err := validateStringParam(secret, "secret"); err != nil {
		return nil, err
	}
	if err := validateStringParam(resource, "resource"); err != nil {
		return nil, err
	}
	auxTenants := multiTenantCfg.AuxiliaryTenants()
	m := MultiTenantServicePrincipalToken{
		AuxiliaryTokens: make([]*ServicePrincipalToken, len(auxTenants)),
	}
	primary, err := NewServicePrincipalToken(*multiTenantCfg.PrimaryTenant(), clientID, secret, resource)
	if err != nil {
		return nil, fmt.Errorf("failed to create SPT for primary tenant: %w", err)
	}
	m.PrimaryToken = primary
	for i := range auxTenants {
		aux, err := NewServicePrincipalToken(*auxTenants[i], clientID, secret, resource)
		if err != nil {
			return nil, fmt.Errorf("failed to create SPT for auxiliary tenant: %w", err)
		}
		m.AuxiliaryTokens[i] = aux
	}
	return &m, nil
}

// MSIAvailable returns true if the MSI endpoint is available for authentication.
//
// It probes the endpoint returned by GetMSIVMEndpoint with a GET request (carrying
// the api-version query parameter) bounded by a 500ms timeout derived from ctx.
// Any response, whatever its status code, counts as available; only a request
// construction or transport error (including the timeout) yields false. The probe
// response body is drained and closed so the underlying connection is released.
func MSIAvailable(ctx context.Context, sender Sender) bool {
	// this cannot fail, the return sig is due to legacy reasons
	endpoint, _ := GetMSIVMEndpoint()
	tempCtx, cancel := context.WithTimeout(ctx, 500*time.Millisecond)
	defer cancel()
	req, err := http.NewRequestWithContext(tempCtx, http.MethodGet, endpoint, nil)
	if err != nil {
		// an endpoint that can't even form a request is not reachable
		return false
	}
	q := req.URL.Query()
	q.Add("api-version", msiAPIVersion)
	req.URL.RawQuery = q.Encode()
	resp, err := sender.Do(req)
	if err != nil {
		return false
	}
	if resp != nil && resp.Body != nil {
		// best effort drain (still bounded by tempCtx) so the connection can be reused
		_, _ = io.Copy(ioutil.Discard, resp.Body)
		resp.Body.Close()
	}
	return true
}