Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -4,16 +4,15 @@
# license information.
# --------------------------------------------------------------------------
from threading import Lock, Condition
from datetime import datetime, timedelta
from datetime import timedelta
from typing import ( # pylint: disable=unused-import
cast,
Tuple,
)

from msrest.serialization import TZ_UTC

from .utils import get_current_utc_as_int
from .user_token_refresh_options import CommunicationTokenRefreshOptions
from .utils import _convert_datetime_to_utc_int


class CommunicationTokenCredential(object):
"""Credential type used for authenticating to an Azure Communication service.
Expand All @@ -36,8 +35,8 @@ def __init__(self,
self._lock = Condition(Lock())
self._some_thread_refreshing = False

def get_token(self):
# type () -> ~azure.core.credentials.AccessToken
def get_token(self, *scopes, **kwargs): # pylint: disable=unused-argument
# type (*str, **Any) -> AccessToken
"""The value of the configured token.
:rtype: ~azure.core.credentials.AccessToken
"""
Expand Down Expand Up @@ -80,14 +79,8 @@ def _wait_till_inprogress_thread_finish_refreshing(self):
self._lock.acquire()

def _token_expiring(self):
return self._token.expires_on - self._get_utc_now_as_int() <\
return self._token.expires_on - get_current_utc_as_int() <\
timedelta(minutes=self._ON_DEMAND_REFRESHING_INTERVAL_MINUTES).total_seconds()

def _is_currenttoken_valid(self):
return self._get_utc_now_as_int() < self._token.expires_on

@classmethod
def _get_utc_now_as_int(cls):
current_utc_datetime = datetime.utcnow().replace(tzinfo=TZ_UTC)
current_utc_datetime_as_int = _convert_datetime_to_utc_int(current_utc_datetime)
return current_utc_datetime_as_int
return get_current_utc_as_int() < self._token.expires_on
Original file line number Diff line number Diff line change
Expand Up @@ -4,30 +4,27 @@
# license information.
# --------------------------------------------------------------------------
from asyncio import Condition, Lock
from datetime import datetime, timedelta
from datetime import timedelta
from typing import ( # pylint: disable=unused-import
cast,
Tuple,
Any
)

from msrest.serialization import TZ_UTC

from .utils import get_current_utc_as_int
from .user_token_refresh_options import CommunicationTokenRefreshOptions
from .utils import _convert_datetime_to_utc_int


class CommunicationTokenCredential(object):
"""Credential type used for authenticating to an Azure Communication service.
:param str token: The token used to authenticate to an Azure Communication service
:keyword token_refresher: The token refresher to provide capacity to fetch fresh token
:keyword token_refresher: The async token refresher to provide capacity to fetch fresh token
:raises: TypeError
"""

_ON_DEMAND_REFRESHING_INTERVAL_MINUTES = 2

def __init__(self,
token, # type: str
**kwargs
):
def __init__(self, token: str, **kwargs: Any):
token_refresher = kwargs.pop('token_refresher', None)
communication_token_refresh_options = CommunicationTokenRefreshOptions(token=token,
token_refresher=token_refresher)
Expand All @@ -36,25 +33,24 @@ def __init__(self,
self._lock = Condition(Lock())
self._some_thread_refreshing = False

def get_token(self):
# type () -> ~azure.core.credentials.AccessToken
async def get_token(self, *scopes, **kwargs): # pylint: disable=unused-argument
# type (*str, **Any) -> AccessToken
"""The value of the configured token.
:rtype: ~azure.core.credentials.AccessToken
"""

if not self._token_refresher or not self._token_expiring():
return self._token

should_this_thread_refresh = False

with self._lock:
async with self._lock:

while self._token_expiring():
if self._some_thread_refreshing:
if self._is_currenttoken_valid():
return self._token

self._wait_till_inprogress_thread_finish_refreshing()
await self._wait_till_inprogress_thread_finish_refreshing()
else:
should_this_thread_refresh = True
self._some_thread_refreshing = True
Expand All @@ -63,34 +59,37 @@ def get_token(self):

if should_this_thread_refresh:
try:
newtoken = self._token_refresher() # pylint:disable=not-callable
newtoken = await self._token_refresher() # pylint:disable=not-callable

with self._lock:
async with self._lock:
self._token = newtoken
self._some_thread_refreshing = False
self._lock.notify_all()
except:
with self._lock:
async with self._lock:
self._some_thread_refreshing = False
self._lock.notify_all()

raise

return self._token

def _wait_till_inprogress_thread_finish_refreshing(self):
async def _wait_till_inprogress_thread_finish_refreshing(self):
self._lock.release()
self._lock.acquire()
await self._lock.acquire()

def _token_expiring(self):
return self._token.expires_on - self._get_utc_now_as_int() <\
timedelta(minutes=self._ON_DEMAND_REFRESHING_INTERVAL_MINUTES).total_seconds()
return self._token.expires_on - get_current_utc_as_int() <\
timedelta(minutes=self._ON_DEMAND_REFRESHING_INTERVAL_MINUTES).total_seconds()

def _is_currenttoken_valid(self):
return self._get_utc_now_as_int() < self._token.expires_on
return get_current_utc_as_int() < self._token.expires_on

async def close(self) -> None:
pass

async def __aenter__(self):
return self

@classmethod
def _get_utc_now_as_int(cls):
current_utc_datetime = datetime.utcnow().replace(tzinfo=TZ_UTC)
current_utc_datetime_as_int = _convert_datetime_to_utc_int(current_utc_datetime)
return current_utc_datetime_as_int
async def __aexit__(self, *args):
await self.close()
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,12 @@
from msrest.serialization import TZ_UTC
from azure.core.credentials import AccessToken


def _convert_datetime_to_utc_int(expires_on):
epoch = time.mktime(datetime(1970, 1, 1).timetuple())
return epoch-time.mktime(expires_on.timetuple())


def parse_connection_str(conn_str):
# type: (str) -> Tuple[str, str, str, str]
endpoint = None
Expand Down Expand Up @@ -43,6 +49,13 @@ def get_current_utc_time():
# type: () -> str
return str(datetime.utcnow().strftime("%a, %d %b %Y %H:%M:%S ")) + "GMT"


def get_current_utc_as_int():
# type: () -> int
current_utc_datetime = datetime.utcnow().replace(tzinfo=TZ_UTC)
return _convert_datetime_to_utc_int(current_utc_datetime)


def create_access_token(token):
# type: (str) -> azure.core.credentials.AccessToken
"""Creates an instance of azure.core.credentials.AccessToken from a
Expand Down Expand Up @@ -71,6 +84,7 @@ def create_access_token(token):
except ValueError:
raise ValueError(token_parse_err_msg)


def get_authentication_policy(
endpoint, # type: str
credential, # type: TokenCredential or str
Expand Down Expand Up @@ -101,7 +115,3 @@ def get_authentication_policy(

raise TypeError("Unsupported credential: {}. Use an access token string to use HMACCredentialsPolicy"
"or a token credential from azure.identity".format(type(credential)))

def _convert_datetime_to_utc_int(expires_on):
epoch = time.mktime(datetime(1970, 1, 1).timetuple())
return epoch-time.mktime(expires_on.timetuple())
Original file line number Diff line number Diff line change
Expand Up @@ -16,7 +16,7 @@
import six
from azure.core.tracing.decorator import distributed_trace
from azure.core.tracing.decorator_async import distributed_trace_async
from azure.core.pipeline.policies import BearerTokenCredentialPolicy
from azure.core.pipeline.policies import AsyncBearerTokenCredentialPolicy
from azure.core.exceptions import HttpResponseError
from azure.core.async_paging import AsyncItemPaged

Expand Down Expand Up @@ -86,7 +86,7 @@ def __init__(

self._client = AzureCommunicationChatService(
self._endpoint,
authentication_policy=BearerTokenCredentialPolicy(self._credential),
authentication_policy=AsyncBearerTokenCredentialPolicy(self._credential),
sdk_moniker=SDK_MONIKER,
**kwargs)

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -15,7 +15,7 @@
import six
from azure.core.tracing.decorator import distributed_trace
from azure.core.tracing.decorator_async import distributed_trace_async
from azure.core.pipeline.policies import BearerTokenCredentialPolicy
from azure.core.pipeline.policies import AsyncBearerTokenCredentialPolicy
from azure.core.async_paging import AsyncItemPaged

from .._shared.user_credential_async import CommunicationTokenCredential
Expand Down Expand Up @@ -103,7 +103,7 @@ def __init__(

self._client = AzureCommunicationChatService(
endpoint,
authentication_policy=BearerTokenCredentialPolicy(self._credential),
authentication_policy=AsyncBearerTokenCredentialPolicy(self._credential),
sdk_moniker=SDK_MONIKER,
**kwargs)

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -25,16 +25,18 @@
import pytest
import time


def _convert_datetime_to_utc_int(input):
epoch = time.mktime(datetime(1970, 1, 1).timetuple())
input_datetime_as_int = epoch - time.mktime(input.timetuple())
return input_datetime_as_int


credential = Mock()
credential.get_token = Mock(return_value=AccessToken(
"some_token", _convert_datetime_to_utc_int(datetime.now().replace(tzinfo=TZ_UTC))
))
async def mock_get_token():
return AccessToken("some_token", _convert_datetime_to_utc_int(datetime.now().replace(tzinfo=TZ_UTC)))

credential = Mock(get_token=mock_get_token)


@pytest.mark.asyncio
async def test_create_chat_thread():
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -25,15 +25,18 @@
import pytest
import time


def _convert_datetime_to_utc_int(input):
epoch = time.mktime(datetime(1970, 1, 1).timetuple())
input_datetime_as_int = epoch - time.mktime(input.timetuple())
return input_datetime_as_int

credential = Mock()
credential.get_token = Mock(return_value=AccessToken(
"some_token", _convert_datetime_to_utc_int(datetime.now().replace(tzinfo=TZ_UTC))
))

async def mock_get_token():
return AccessToken("some_token", _convert_datetime_to_utc_int(datetime.now().replace(tzinfo=TZ_UTC)))

credential = Mock(get_token=mock_get_token)


@pytest.mark.asyncio
async def test_update_topic():
Expand Down