// Copyright 2017 Microsoft Corporation
//
//  Licensed under the Apache License, Version 2.0 (the "License");
//  you may not use this file except in compliance with the License.
//  You may obtain a copy of the License at
//
//      http://www.apache.org/licenses/LICENSE-2.0
//
//  Unless required by applicable law or agreed to in writing, software
//  distributed under the License is distributed on an "AS IS" BASIS,
//  WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
//  See the License for the specific language governing permissions and
//  limitations under the License.

package azure

import (
	"errors"
	"fmt"
	"net/http"
	"net/url"
	"strings"
	"time"

	"github.com/Azure/go-autorest/autorest"
)

// DoRetryWithRegistration returns a SendDecorator that sends the request up to
// client.RetryAttempts times. When a response comes back as 409 Conflict and its
// decoded service error code is "MissingSubscriptionRegistration", it tries to
// register the resource provider before the next attempt.
// Any non-409 response, or any response at all when
// client.SkipResourceProviderRegistration is set, is returned to the caller as is.
func DoRetryWithRegistration(client autorest.Client) autorest.SendDecorator {
	return func(s autorest.Sender) autorest.Sender {
		return autorest.SenderFunc(func(r *http.Request) (resp *http.Response, err error) {
			retriable := autorest.NewRetriableRequest(r)
			for attempt := 0; attempt < client.RetryAttempts; attempt++ {
				if err = retriable.Prepare(); err != nil {
					return resp, err
				}

				resp, err = autorest.SendWithSender(s, retriable.Request(),
					autorest.DoRetryForStatusCodes(client.RetryAttempts, client.RetryDuration, autorest.StatusCodesForRetry...),
				)
				if err != nil {
					return resp, err
				}

				// Only a 409 Conflict is inspected further, and only when
				// auto-registration has not been switched off on the client.
				isConflict := resp.StatusCode == http.StatusConflict
				if !isConflict || client.SkipResourceProviderRegistration {
					return resp, err
				}

				var reqErr RequestError
				decode := autorest.ByUnmarshallingJSON(&reqErr)
				// NOTE(review): this inspects the Content-Type header of the request,
				// not of the response — confirm that is intended.
				if strings.Contains(r.Header.Get("Content-Type"), "xml") {
					// XML errors (e.g. Storage Data Plane) only return the inner object
					decode = autorest.ByUnmarshallingXML(&reqErr.ServiceError)
				}
				if err = autorest.Respond(resp, decode); err != nil {
					return resp, err
				}
				// Keep the decoded service error as the result should the attempts run out.
				err = reqErr

				// Any other conflict simply moves on to the next attempt.
				if reqErr.ServiceError == nil || reqErr.ServiceError.Code != "MissingSubscriptionRegistration" {
					continue
				}
				if regErr := register(client, r, reqErr); regErr != nil {
					return resp, fmt.Errorf("failed auto registering Resource Provider: %s. Original error: %s", regErr, err)
				}
			}
			return resp, err
		})
	}
}

// getProvider extracts the resource provider namespace from the "target" entry
// of the first detail attached to the service error in re.
// It returns an error when there is no service error, when the service error
// carries no details, or when the "target" entry is missing or not a string.
func getProvider(re RequestError) (string, error) {
	if re.ServiceError == nil || len(re.ServiceError.Details) == 0 {
		return "", errors.New("provider was not found in the response")
	}
	// Use the comma-ok form so that a missing or non-string target is reported
	// as an error instead of panicking inside the request pipeline.
	target, ok := re.ServiceError.Details[0]["target"].(string)
	if !ok {
		return "", errors.New("provider was not found in the response")
	}
	return target, nil
}

// register asks the service to register the resource provider named in re for
// the subscription found in originalReq's URL path, then polls the provider
// until its registration state is "Registered".
//
// The registration and polling requests reuse the scheme and host of
// originalReq and carry its context. Polling stops with an error once
// client.PollingDuration has elapsed; a PollingDuration of zero means the loop
// is not time-bounded. It returns nil once the provider reports "Registered",
// the context's error when a polling delay is cancelled, and otherwise the
// first error hit while building, sending or decoding a request.
func register(client autorest.Client, originalReq *http.Request, re RequestError) error {
	subID := getSubscription(originalReq.URL.Path)
	if subID == "" {
		return errors.New("missing parameter subscriptionID to register resource provider")
	}
	providerName, err := getProvider(re)
	if err != nil {
		return fmt.Errorf("missing parameter provider to register resource provider: %s", err)
	}
	newURL := url.URL{
		Scheme: originalReq.URL.Scheme,
		Host:   originalReq.URL.Host,
	}

	// Taken from the resources SDK.
	// With almost identical code, these sections are easier to maintain.
	// It is also not a good idea to import the SDK here.
	// https://github.com/Azure/azure-sdk-for-go/blob/9f366792afa3e0ddaecdc860e793ba9d75e76c27/arm/resources/resources/providers.go#L252
	pathParameters := map[string]interface{}{
		"resourceProviderNamespace": autorest.Encode("path", providerName),
		"subscriptionId":            autorest.Encode("path", subID),
	}

	const APIVersion = "2016-09-01"
	queryParameters := map[string]interface{}{
		"api-version": APIVersion,
	}

	registerPreparer := autorest.CreatePreparer(
		autorest.AsPost(),
		autorest.WithBaseURL(newURL.String()),
		autorest.WithPathParameters("/subscriptions/{subscriptionId}/providers/{resourceProviderNamespace}/register", pathParameters),
		autorest.WithQueryParameters(queryParameters),
	)

	req, err := registerPreparer.Prepare(&http.Request{})
	if err != nil {
		return err
	}
	req = req.WithContext(originalReq.Context())

	resp, err := autorest.SendWithSender(client, req,
		autorest.DoRetryForStatusCodes(client.RetryAttempts, client.RetryDuration, autorest.StatusCodesForRetry...),
	)
	if err != nil {
		return err
	}

	type Provider struct {
		RegistrationState *string `json:"registrationState,omitempty"`
	}
	var provider Provider

	err = autorest.Respond(
		resp,
		WithErrorUnlessStatusCode(http.StatusOK),
		autorest.ByUnmarshallingJSON(&provider),
		autorest.ByClosing(),
	)
	if err != nil {
		return err
	}

	// Poll for the registered provisioning state. Every failure inside the loop
	// returns immediately, so the loop condition only has to watch the clock.
	registrationStartTime := time.Now()
	for client.PollingDuration == 0 || time.Since(registrationStartTime) < client.PollingDuration {
		// Taken from the resources SDK.
		// https://github.com/Azure/azure-sdk-for-go/blob/9f366792afa3e0ddaecdc860e793ba9d75e76c27/arm/resources/resources/providers.go#L45
		statusPreparer := autorest.CreatePreparer(
			autorest.AsGet(),
			autorest.WithBaseURL(newURL.String()),
			autorest.WithPathParameters("/subscriptions/{subscriptionId}/providers/{resourceProviderNamespace}", pathParameters),
			autorest.WithQueryParameters(queryParameters),
		)
		req, err = statusPreparer.Prepare(&http.Request{})
		if err != nil {
			return err
		}
		req = req.WithContext(originalReq.Context())

		// Plain assignment: reuse the outer resp/err instead of shadowing them.
		resp, err = autorest.SendWithSender(client, req,
			autorest.DoRetryForStatusCodes(client.RetryAttempts, client.RetryDuration, autorest.StatusCodesForRetry...),
		)
		if err != nil {
			return err
		}

		err = autorest.Respond(
			resp,
			WithErrorUnlessStatusCode(http.StatusOK),
			autorest.ByUnmarshallingJSON(&provider),
			autorest.ByClosing(),
		)
		if err != nil {
			return err
		}

		if provider.RegistrationState != nil && *provider.RegistrationState == "Registered" {
			// Return right away: a registration confirmed by a poll that happened
			// to finish after the deadline must not be reported as a timeout.
			return nil
		}

		// Wait as instructed by the response; when that wait did not happen,
		// fall back to the client's polling delay. A false result from the
		// fallback means the request context was cancelled.
		delayed := autorest.DelayWithRetryAfter(resp, originalReq.Context().Done())
		if !delayed && !autorest.DelayForBackoff(client.PollingDelay, 0, originalReq.Context().Done()) {
			return originalReq.Context().Err()
		}
	}
	// The loop can only fall through once the polling duration has been used up.
	return errors.New("polling for resource provider registration has exceeded the polling duration")
}

// getSubscription returns the path segment that immediately follows the first
// "subscriptions" segment of a slash-separated path. It returns the empty
// string when no "subscriptions" segment is followed by another segment.
func getSubscription(path string) string {
	segments := strings.Split(path, "/")
	// Stop one short of the end so that segments[idx+1] is always in range.
	for idx := 0; idx+1 < len(segments); idx++ {
		if segments[idx] == "subscriptions" {
			return segments[idx+1]
		}
	}
	return ""
}