# Source code for mltk.utils.audio_dataset_generator.generator_types
from __future__ import annotations
import abc
import copy
from typing import List, NamedTuple
import time
import threading
import logging
from enum import Enum
import hashlib
from dataclasses import dataclass
# Module-level logger. It uses the fixed name 'AudioDatasetGenerator' instead of __name__,
# presumably so that all audio-dataset-generator modules log through one shared logger — confirm.
# NOTE(review): not referenced in the visible portion of this module.
logger = logging.getLogger('AudioDatasetGenerator')
class Keyword(NamedTuple):
    """Keyword to generate
    """
    value:str
    """The base keyword

    .. note:: If the base keyword starts with an underscore,
       then it is NOT included in the list of keywords to generate (see :py:func:`~as_list`).
       In this case, the :py:attr:`~aliases` must be provided.
    """
    aliases:List[str]=None
    """Additional aliases for the base keyword"""
    max_count:int=None
    """The maximum number of samples to generate for this keyword

    This is only used if the ``truncate`` argument of :py:func:`~AudioDatasetGenerator.list_configurations` is ``True``
    """

    def as_list(self) -> List[str]:
        """Return the base keyword and all its aliases as a list of strings

        .. note:: If the base keyword (i.e. :py:attr:`~value`) starts with an underscore, then it is omitted from the list

        Returns:
            A new list: the base keyword (unless hidden by a leading underscore)
            followed by the aliases in their original order. If the base keyword is
            hidden and no aliases are set, the list is empty.
        """
        # A leading underscore marks the base keyword as a group label only, not text to generate
        retval = [] if self.value.startswith('_') else [self.value]
        if self.aliases:
            retval.extend(self.aliases) # pylint:disable=not-an-iterable
        return retval

    def __str__(self) -> str:
        return self.value
class Augmentation(NamedTuple):
    """Augmentations to apply to the audio sample"""
    pitch:VoicePitch=None
    """The pitch of the voice"""
    rate:VoiceRate=None
    """The speaking rate of the voice"""

    def __str__(self) -> str:
        """Human-readable summary, e.g. ``Pitch:low Rate:x-fast``

        Enum members are rendered by their ``.value``. Any other field value
        (such as the ``None`` default) is rendered as-is instead of raising
        ``AttributeError``.
        """
        pitch = getattr(self.pitch, 'value', self.pitch)
        rate = getattr(self.rate, 'value', self.rate)
        return f'Pitch:{pitch} Rate:{rate}'
class VoicePitch(str, Enum):
    """Pitch of the voice used to synthesize an audio sample.

    Because the enum mixes in ``str``, every member compares equal to its
    raw string value (for example ``VoicePitch.low == 'low'``).
    """
    low = 'low'
    medium = 'medium'
    high = 'high'
    # Same value as ``medium``, so Enum registers this name as an alias
    # (``VoicePitch.default is VoicePitch.medium``) rather than a new member.
    default = 'medium'
class VoiceRate(str, Enum):
    """Speaking rate of the voice used to synthesize an audio sample.

    Because the enum mixes in ``str``, every member compares equal to its
    raw string value (for example ``VoiceRate.xfast == 'x-fast'``).
    """
    xslow = 'x-slow'
    medium = 'medium'
    xfast = 'x-fast'
    # Same value as ``medium``, so Enum registers this name as an alias
    # (``VoiceRate.default is VoiceRate.medium``) rather than a new member.
    default = 'medium'
@dataclass
class Voice:
    """The backend voice used to generate a keyword"""
    name:str
    """The name of the voice as specified by the backend"""
    language_code:str
    """The language code of the voice as specified by the backend"""
    backend:str
    """The name of the voice's backend"""

    def hashable_value(self) -> str:
        """The value used to generate a unique "hash" for the voice

        Returns:
            The concatenation of :py:attr:`name`, :py:attr:`language_code` and :py:attr:`backend`
        """
        # NOTE(review): the fields are joined without a separator, so distinct voices could in
        # principle produce the same value (e.g. 'ab' + 'c' vs 'a' + 'bc'). Left unchanged
        # because altering it would change every hex_hash, which may already be persisted.
        return self.name + self.language_code + self.backend

    @property
    def hex_hash(self) -> str:
        """A short hash for the voice: the first 8 hex characters of the SHA-1 of :py:meth:`hashable_value`

        This may be used to group samples so that the same voice does not
        appear in the "training" and "validation" subsets.

        The hash is recomputed on every access. ``Voice`` is a mutable dataclass,
        so a cached digest would go stale if a field were reassigned, and hashing
        a short string is cheap.
        """
        digest = hashlib.sha1(self.hashable_value().encode('utf-8')).hexdigest()
        return digest[:8]

    def __hash__(self):
        # Hash exactly the fields the dataclass-generated __eq__ compares, so equal voices hash equally.
        # @dataclass (eq=True, frozen=False) keeps an explicitly defined __hash__ like this one.
        return hash((self.name, self.language_code, self.backend))
@dataclass
class GenerationConfig:
    """Settings describing a single audio sample to generate.

    Pairs a backend :py:class:`Voice` with a speaking rate and pitch, plus the
    keyword text that should be spoken.
    """
    voice:Voice
    """Backend voice that speaks the sample"""
    rate:VoiceRate
    """Speaking rate applied to the voice"""
    pitch:VoicePitch
    """Pitch applied to the voice"""
    keyword:str=None
    """Text that is spoken: either the base keyword or one of its aliases"""
    keyword_group:str=None
    """Base keyword that :py:attr:`keyword` belongs to"""

    def copy(self) -> GenerationConfig:
        """Create an independent deep copy of this configuration.

        Nested objects such as :py:attr:`voice` are copied as well, so the
        result can be modified without affecting the original.
        """
        # Inside this method, ``copy`` refers to the module-level ``copy`` module, not to
        # this method: function bodies do not see names from the enclosing class scope.
        duplicate = copy.deepcopy(self)
        return duplicate
class BackendBase(abc.ABC):
    """Base class for a cloud backend

    Concrete backends implement the abstract members below. The base class
    also provides simple client-side rate limiting: :py:attr:`is_rate_limited`
    reports whether less than the minimum interval has elapsed since the last
    :py:meth:`update_generate_timestamp` call.
    """

    # Fraction of the advertised transactions-per-second that is actually targeted,
    # leaving headroom below the backend's limit. Subclasses may override this.
    _RATE_SAFETY_FACTOR = 0.85

    def __init__(self, transactions_per_second:float):
        """
        Args:
            transactions_per_second: Maximum request rate allowed by the backend; must be > 0

        Raises:
            ValueError: If ``transactions_per_second`` is not a positive number
        """
        super().__init__()
        # ``not x > 0`` also rejects NaN. A zero rate previously raised ZeroDivisionError,
        # and a negative rate silently disabled rate limiting.
        if not transactions_per_second > 0:
            raise ValueError(f'transactions_per_second must be > 0, got {transactions_per_second!r}')
        # time.monotonic() reading of the most recent transaction. -inf means "no transaction
        # yet", so the first check is never rate limited regardless of the monotonic clock's origin.
        self._generate_timestamp:float = float('-inf')
        self._min_seconds_per_transaction = 1/(transactions_per_second * self._RATE_SAFETY_FACTOR)
        self._lock = threading.Lock()

    @property
    def is_rate_limited(self) -> bool:
        """``True`` if the minimum interval between transactions has not yet elapsed
        since the last :py:meth:`update_generate_timestamp` call

        .. note:: Checking this property and then calling :py:meth:`update_generate_timestamp`
           is not atomic. Two threads can both observe ``False`` before either records its
           timestamp.
        """
        with self._lock:
            # Use a monotonic clock: wall-clock time (time.time) can jump when the system
            # clock is adjusted, making elapsed negative or spuriously large.
            elapsed = time.monotonic() - self._generate_timestamp
            return elapsed < self._min_seconds_per_transaction

    def update_generate_timestamp(self):
        """Record "now" as the time of the most recent transaction"""
        with self._lock:
            self._generate_timestamp = time.monotonic()

    @property
    @abc.abstractmethod
    def name(self) -> str:
        """Name of the backend

        Presumably matches :py:attr:`Voice.backend` for this backend's voices — confirm in implementations.
        """

    @abc.abstractmethod
    def load(self, install_python_package=False, **kwargs):
        """Load/initialize the backend

        ``install_python_package`` presumably allows a missing client package to be installed;
        ``kwargs`` are backend-specific. Confirm against implementations.
        """

    @abc.abstractmethod
    def list_languages(self) -> List[str]:
        """Return the language codes supported by the backend (presumably)"""

    @abc.abstractmethod
    def list_voices(self, language_code:str=None) -> List[Voice]:
        """Return the backend's voices, presumably filtered by ``language_code`` when given"""

    @abc.abstractmethod
    def list_configurations(
        self,
        augmentations:List[Augmentation],
        voice:Voice,
    ) -> List[GenerationConfig]:
        """Return the generation configurations for ``voice`` combined with ``augmentations``"""

    @abc.abstractmethod
    def count_characters(self, config:GenerationConfig) -> int:
        """Return the number of characters ``config`` would submit to the backend

        Presumably used for quota/cost estimation — confirm.
        """

    @abc.abstractmethod
    def generate(self, config:GenerationConfig, out_dir:str) -> str:
        """Generate the audio sample described by ``config`` into ``out_dir``

        Returns a string, presumably the path of the generated file — confirm in implementations.
        """

    @abc.abstractmethod
    def generate_filename(self, config:GenerationConfig) -> str:
        """Return the file name to use for the sample described by ``config``"""