def load_mltk_model(
    model: str,
    test: bool = False,
    print_not_found_err=False,
    logger: logging.Logger = None,
    reload: bool = True,
) -> MltkModel:
    """Find a MLTK model with the given name and instantiate its corresponding :py:class:`mltk.core.MltkModel` object

    Args:
        model: Name of MLTK model or path to MLTK model's python specification script or archive.
            Append ``-test`` to the MLTK model name to load it in "testing" mode
            (this is the same as setting the argument: ``test=True``)
        test: If the MLTK model should be loaded in "testing" mode
        print_not_found_err: If true and the MLTK model is not found, then print an error
        logger: Optional logger, the MLTK logger is used if omitted
        reload: If the given model is a python script then reload the module if necessary.
            This applies both when a file path is given and when the model is found by name.

    Returns:
        Loaded model object

    Raises:
        Exception: If ``model`` is not a string, or no specification file is found for the given model name
        FileNotFoundError: If ``model`` is a .mltk.zip, .tflite, .h5 or .py path that does not exist
        ValueError: If ``model`` is neither an existing file nor a valid model name
    """
    logger = logger or get_mltk_logger()
    if not isinstance(model, str):
        raise Exception('Model argument must be a string')

    # Resolve any path variables if applicable
    model_path = fullpath(model)

    # If a model file path was given then load it directly
    if os.path.exists(model_path):
        return load_mltk_model_with_path(model_path, test=test, logger=logger, reload=reload)
    # A file-like argument that does not exist is an error;
    # do NOT fall through and treat it as a model name
    elif model_path.endswith('.mltk.zip'):
        raise FileNotFoundError(f'Model archive not found: {model_path}')
    elif model_path.endswith(('.tflite', '.h5')):
        raise FileNotFoundError(f'Model file not found: {model_path}')
    elif model_path.endswith('.py'):
        raise FileNotFoundError(f'Model specification script not found: {model_path}')

    # Otherwise the name of an MLTK model was given
    # So attempt to find it on the search paths

    # If '-test' was appended to the MLTK model name
    # then load the model in "test" mode
    if model.endswith('-test'):
        test = True
        model = model[:-len('-test')]

    # fullmatch: unlike match() with '$', this rejects a trailing newline
    if not re.fullmatch(r'[a-zA-Z0-9_]+', model):
        raise ValueError('Invalid MLTK model argument given. Must either be the path to an existing model file (.tflite, .h5, .mltk.zip) or must contain only letters, numbers, or an underscore')

    logger.debug(f'Searching for MLTK model: {model}')
    model_spec_path = find_model_specification_file(
        model=model,
        test=test,
        logger=logger,
        print_not_found_err=print_not_found_err,
    )
    if not model_spec_path:
        raise Exception(f'Failed to find model specification file with name: {model}.py')

    # Forward "reload" so it is honored for models found by name, not only for file paths
    return load_mltk_model_with_path(model_path=model_spec_path, test=test, logger=logger, reload=reload)
def load_mltk_model_with_path(
    model_path: str,
    test: bool = False,
    logger: logging.Logger = None,
    reload: bool = True,
) -> MltkModel:
    """Instantiate a MltkModel object from the given model path

    The model path can be a ``.py`` model specificaton or a model archive ``.mltk.zip``.
    When an archive is given, it is copied to a temp directory and its ``<model name>.py``
    specification is extracted from that copy. An archive ending with the test archive
    extension loads the model in test mode.

    Args:
        model_path: Path to a ``.py`` model specification or a model archive
        test: If the model should be loaded in "testing" mode
        logger: Optional logger, the MLTK logger is used if omitted
        reload: Passed to ``import_module_at_path()`` when importing the specification

    Returns:
        The first MltkModel instance found (in ``dir()`` order) in the imported specification module

    Raises:
        Exception: If the path has an unsupported extension, or the specification
            does not define a MltkModel object
    """
    if not model_path.endswith(('.py', ARCHIVE_EXTENSION)):
        raise Exception('Model path must have either .mltk.zip or .py extension')

    logger = logger or get_mltk_logger()
    # Expand any path variables
    model_path = fullpath(model_path)

    if model_path.endswith(ARCHIVE_EXTENSION):
        # Work on a copy of the archive inside a temp directory
        # and pull the model specification script out of it
        archive_filename = os.path.basename(model_path)
        model_name = archive_filename.replace(TEST_ARCHIVE_EXTENSION, '').replace(ARCHIVE_EXTENSION, '')
        temp_dir = create_tempdir(f'tmp_model_specs/{model_name}')
        shutil.copy(model_path, temp_dir)
        logger.info(f'Extracting {model_name}.py from {model_path}')
        spec_path = extract_file(
            archive_path=f'{temp_dir}/{archive_filename}',
            name=f'{model_name}.py',
            dest_dir=temp_dir,
        )
        if model_path.endswith(TEST_ARCHIVE_EXTENSION):
            test = True
    else:
        spec_stem = os.path.splitext(model_path)[0].replace('\\', '/')
        spec_path = f'{spec_stem}.py'

    try:
        logger.debug(f'Importing {spec_path}')
        spec_module = import_module_at_path(spec_path, reload=reload)
    except Exception as err:
        prepend_exception_msg(err, f'Failed to import MLTK model module: {spec_path}')
        raise

    for attr_name in dir(spec_module):
        candidate = getattr(spec_module, attr_name)
        if not isinstance(candidate, MltkModel):
            continue

        # Warn when the model was generated with a different major.minor mltk release,
        # e.g. a model archive created with an older version of the mltk
        current_version = _parse_version(mltk_version_str)
        saved_version_str = getattr(spec_module, '__mltk_version__', None)
        if saved_version_str is not None:
            saved_version = _parse_version(saved_version_str)
            if (saved_version.major, saved_version.minor) != (current_version.major, current_version.minor):
                logger.warning(
                    f'Model {candidate.name} was created with mltk version: {saved_version_str} '
                    f'but current mltk version is: {mltk_version_str}'
                )

        if test:
            candidate.enable_test_mode()

        # All of the model properties are registered and populated by now
        candidate.trigger_event(MltkModelEvent.AFTER_MODEL_LOAD)
        return candidate

    raise Exception(f'Model specification file: {spec_path} does not define a MltkModel object')
def load_tflite_or_keras_model(
    model: Union[MltkModel, str],
    model_type: str = None,
    weights: str = None,
    logger: logging.Logger = None,
) -> Union[TfliteModel, KerasModel]:
    """Instantiate a Keras or TfliteModel object

    IF model is an :py:class:`mltk.core.MltkModel` instance OR a model archive `.mltk.zip`,
    AND model_type is:

    - ``None`` -> return built :py:class:`mltk.core.KerasModel` from model specification
    - ``tflite`` (or ``.tflite``) -> return loaded :py:class:`mltk.core.TfliteModel` from model archive
    - ``h5`` (or ``.h5``, ``keras``) -> return loaded :py:class:`mltk.core.KerasModel` from model archive

    ELSE model should be the file path to a `.tflite` or `.h5` model file.

    Args:
        model: MltkModel instance, path to a ``.mltk.zip`` archive, or path to a ``.h5`` / ``.tflite`` file
        model_type: Selects what is returned for an MltkModel / archive, see above
        weights: Optional weights to load into a Keras model. For an MltkModel this is
            passed to ``model.get_weights_path()``; for a file path it is used directly.
            For a .tflite model only a warning is logged.
        logger: Optional logger, the MLTK logger is used if omitted

    Returns:
        The built or loaded KerasModel or TfliteModel

    Raises:
        Exception: For unsupported ``model`` / ``model_type`` arguments
        RuntimeError: If ``build_model_function`` does not return a Keras model
    """
    from .mixins.train_mixin import TrainMixin  # pylint: disable=import-outside-toplevel

    logger = logger or get_mltk_logger()
    is_archive_path = isinstance(model, str) and model.endswith('.mltk.zip')

    # Initialize the GPU if necessary
    if (isinstance(model, MltkModel) and model_type is None) \
            or (isinstance(model, str) and model.endswith(('.h5', '.mltk.zip'))):
        gpu.initialize(logger=logger)

    if isinstance(model, MltkModel) or is_archive_path:
        if is_archive_path:
            model = load_mltk_model(model, logger=logger)

        if model_type is None:
            if not isinstance(model, TrainMixin):
                raise Exception('MltkModel must inherit TrainMixin')
            logger.debug('Building Keras model')
            # Ensure test mode is enabled while building.
            # The caller's setting is restored in the "finally" clause,
            # even if building or the save callback raises
            test_mode_enabled = model.attributes['test_mode_enabled']
            model.attributes['test_mode_enabled'] = True
            try:
                built_model = model.build_model_function(model)
                if built_model is None:
                    raise RuntimeError('Your "my_model.build_model_function" must return the compiled Keras model (did you forget to add the "return keras_model" statement at the end?)')
                if not isinstance(built_model, KerasModel):
                    raise RuntimeError('Your "my_model.build_model_function" must return the compiled Keras model instance')

                on_save_model = getattr(model, 'on_save_keras_model', None)
                if on_save_model is not None:
                    try:
                        built_model = on_save_model(mltk_model=model, keras_model=built_model, logger=logger)
                    except Exception as e:
                        prepend_exception_msg(e, 'Error while calling my_model.on_save_keras_model')
                        raise
            finally:
                # Restore whatever the test mode state was
                model.attributes['test_mode_enabled'] = test_mode_enabled

        elif model_type in ('h5', '.h5', 'keras'):
            h5_path = model.h5_archive_path
            try:
                logger.debug(f'Loading Keras model from {model.archive_path}')
                built_model = load_keras_model(h5_path, custom_objects=model.keras_custom_objects)
            except Exception as e:
                prepend_exception_msg(e, 'Failed to load Keras .h5 file')
                raise

        elif model_type in ('tflite', '.tflite'):
            tflite_path = model.tflite_archive_path
            try:
                logger.debug(f'Loading .tflite model from {model.archive_path}')
                built_model = TfliteModel.load_flatbuffer_file(tflite_path)
            except Exception as e:
                prepend_exception_msg(e, 'Failed to load .tflite file')
                raise

        else:
            raise Exception('model_type must be h5, keras, tflite or None')

    elif isinstance(model, str):
        if model.endswith('.h5'):
            try:
                logger.debug(f'Loading Keras model from {model}')
                built_model = load_keras_model(model)
            except Exception as e:
                prepend_exception_msg(e, 'Failed to load Keras .h5 file')
                raise
        elif model.endswith('.tflite'):
            try:
                logger.debug(f'Loading .tflite model from {model}')
                built_model = TfliteModel.load_flatbuffer_file(model)
            except Exception as e:
                prepend_exception_msg(e, 'Failed to load .tflite file')
                raise
        else:
            raise Exception('Must provide path to .h5 or .tflite model file')
    else:
        raise Exception('model must be a str or MltkModel')

    if weights:
        if isinstance(built_model, KerasModel):
            # A file path argument means "weights" is itself a file path;
            # otherwise let the MltkModel resolve it
            weights_file = weights if isinstance(model, str) else model.get_weights_path(weights)
            logger.info(f'Loading weights: {weights_file}')
            built_model.load_weights(weights_file)
        else:
            logger.warning('Loading weights into .tflite model not supported')

    return built_model
def load_tflite_model(
    model: Union[str, MltkModel, TfliteModel],
    build: bool = False,
    print_not_found_err: bool = False,
    return_tflite_path: bool = False,
    test: bool = False,
    logger: logging.Logger = None,
    archive_file_ext: str = None,
) -> Union[TfliteModel, str]:
    """Return the path to a .tflite model file or a TfliteModel instance

    Args:
        model: One of the following:

            - An MltkModel model instance
            - An TfliteModel model instance
            - The path to a .tflite
            - The path to a .mltk.zip model archive
            - The path to a .py MLTK model specification
            - The name of an MLTK model

        build: If the given Mltk model should be built into a .tflite
        print_not_found_err: If the model is not found, print possible alternatives and exit
        return_tflite_path: If true, return the file path to the .tflite, otherwise return a TfliteModel instance
        test: If a "test" model is provided
        logger: Optional logger
        archive_file_ext: The extension of the .tflite model file in the mltk archive, e.g. .streaming.tflite
            This is only used if the "model" argument is the path to a .mltk.zip, the path to a .py MLTK model specification,
            or the name of an MLTK model

    Returns:
        The corresponding TfliteModel if return_tflite_path=False or the path to the .tflite if return_tflite_path=True

    Raises:
        RuntimeError: If ``build`` is combined with an unsupported model argument
        ValueError: If a .h5 path is given or the model specification cannot be found
        FileNotFoundError: If a .tflite path does not exist and ``return_tflite_path`` is set
        TypeError: If ``model`` is not one of the supported types
    """
    logger = logger or get_mltk_logger()
    mltk_model: MltkModel = None
    tflite_model: TfliteModel = None
    model_name = None

    if isinstance(model, MltkModel):
        mltk_model = model
        model_name = mltk_model.name
        # Only resolve the trained .tflite when NOT building.
        # When building, the trained model is not used (and may not exist yet).
        # Also, the resulting .tflite path would trip the
        # "--build with .tflite" check below.
        if not build:
            model = mltk_model.tflite_archive_path

    if isinstance(model, TfliteModel):
        if build:
            raise RuntimeError('Cannot use build option with TfliteModel instance')
        tflite_model = model
        model_name = (tflite_model.filename or 'my_model.tflite')[:-len('.tflite')]

    elif isinstance(model, str):
        if build and model.endswith(('.tflite', '.mltk.zip')):
            raise RuntimeError('Cannot use --build option with .tflite or .mltk.zip model argument. Must be model name or path to model specification (.py)')
        if model.endswith('.h5'):
            raise ValueError('Must provide .tflite or .mltk.zip model file type')

        if model.endswith('.tflite'):
            model = fullpath(model)
            if return_tflite_path:
                if not os.path.exists(model):
                    raise FileNotFoundError(f'tflite model path not found: {model}')
                return model
            tflite_model = TfliteModel.load_flatbuffer_file(model)
            model_name = tflite_model.filename[:-len('.tflite')]

        elif not model.endswith('.mltk.zip'):
            if build:
                mltk_model = load_mltk_model(model, test=test, logger=logger, print_not_found_err=print_not_found_err)
            else:
                # Convert the model name / .py path into the path of its model archive
                model_spec_path = find_model_specification_file(
                    model=model,
                    test=test,
                    logger=logger,
                    print_not_found_err=print_not_found_err,
                )
                if model_spec_path is None:
                    raise ValueError(f'Failed to find model specification file with name: {model}.py')
                if model.endswith('-test'):
                    test = True
                model = model_spec_path[:-len('.py')]
                if test:
                    model += '-test'
                model += '.mltk.zip'

        if model.endswith('.mltk.zip'):
            model_name = os.path.basename(model[:-len('.mltk.zip')])
            if archive_file_ext:
                if not archive_file_ext.startswith('.'):
                    archive_file_ext = '.' + archive_file_ext
                tflite_name = f'{model_name}{archive_file_ext}'
            elif model_name.endswith('-test'):
                model_name = model_name[:-len('-test')]
                tflite_name = f'{model_name}.test.tflite'
            else:
                tflite_name = f'{model_name}.tflite'

            tflite_path = extract_file(model, tflite_name)
            if return_tflite_path:
                return tflite_path
            tflite_model = TfliteModel.load_flatbuffer_file(tflite_path)

    elif not isinstance(model, MltkModel):
        raise TypeError('model must be a str, MltkModel or TfliteModel')

    if build:
        from ..quantize_model import quantize_model  # pylint: disable=import-outside-toplevel
        if mltk_model is None:
            raise RuntimeError('Must provide MltkModel instance, name of MltkModel, or .py path to model specification to use the build option')
        logger.info('--build option provided, building model rather than using trained model')
        tflite_model = quantize_model(model=mltk_model, build=True, output='tflite_model')

    if return_tflite_path:
        tflite_path = create_tempdir('tmp_models') + f'/{model_name}.tflite'
        tflite_model.save(tflite_path)
        return tflite_path

    if tflite_model is None:
        raise RuntimeError('Failed to load .tflite model')
    return tflite_model
def list_mltk_models(
    test: bool = False,
    for_utests=False,
    logger: logging.Logger = None,
) -> List[str]:
    """Return a list of all found MLTK model names

    A model name is reported when either:

    - a ``.py`` file in a search directory has a line containing ``@mltk_model``
      surrounded by whitespace, or
    - a model archive with the matching archive extension exists in a search directory

    Args:
        test: If true, list test model archives instead of regular model archives
        for_utests: If true, skip ``.py`` files that contain ``@mltk_utest_disabled``
            (surrounded by whitespace)
        logger: Optional logger, the MLTK logger is used if omitted

    Returns:
        Sorted list of unique model names
    """
    logger = logger or get_mltk_logger()
    found_models = set()
    search_dirs = _get_model_search_dirs()
    # The search dirs are fullpath-resolved, so compare against the resolved CWD.
    # (comparing against os.curdir, i.e. '.', never matches)
    cwd = fullpath(os.getcwd())
    archive_ext = get_archive_extension(test=False)
    test_archive_ext = get_archive_extension(test=True)
    mltk_model_re = re.compile(r'.*\s@mltk_model\s.*')
    utest_disable_re = re.compile(r'.*\s@mltk_utest_disabled\s.*')

    def _process_python_file(py_path):
        # True if the file is tagged as an MLTK model (and, for utests, not disabled)
        retval = False
        # Python source is UTF-8 by default (PEP 3120); don't depend on the locale encoding
        with open(py_path, 'r', encoding='utf-8', errors='ignore') as f:
            for line in f:
                if for_utests and utest_disable_re.match(line):
                    return False
                if mltk_model_re.match(line):
                    retval = True
                    # For utests keep scanning for a "disabled" tag
                    if not for_utests:
                        break
        return retval

    for search_dir in search_dirs:
        for root, _, files in walk_with_depth(search_dir, depth=5, followlinks=True):
            for fn in files:
                if fn.endswith('.py'):
                    p = f'{root}/{fn}'.replace('\\', '/')
                    try:
                        if _process_python_file(p):
                            found_models.add(fn[:-len('.py')])
                    except Exception as e:
                        logger.warning(f'Failed to process Python file: {p}, err: {e}')
                if test:
                    if fn.endswith(test_archive_ext):
                        found_models.add(fn[:-len(test_archive_ext)])
                elif fn.endswith(archive_ext) and not fn.endswith(test_archive_ext):
                    found_models.add(fn[:-len(archive_ext)])

            # Do NOT recurse into the CWD
            if search_dir == cwd:
                break

    return sorted(found_models)
def find_model_specification_file(
    model: str,
    test: bool = False,
    logger: logging.Logger = None,
    print_not_found_err: bool = False,
) -> str:
    """Given the model name, attempt to find its corresponding python specification file.

    The specification file could be in a model archive.

    Args:
        model: MLTK model name, optionally prefixed with a sub-directory (e.g. ``dir/my_model``),
            or the path to a ``.py`` model specification. A ``-test`` suffix enables test mode.
        test: If true, look for the test model archive extension
        logger: Optional logger, the MLTK logger is used if omitted
        print_not_found_err: If true and the model is not found, print possible alternatives and exit

    Returns:
        Path to the ``.py`` specification file, or ``None`` if it was not found
    """
    logger = logger or get_mltk_logger()
    search_dirs = _get_model_search_dirs()
    cwd = fullpath(os.getcwd())

    if model.endswith('-test'):
        test = True
        model = model[:-len('-test')]

    if model.endswith('.py'):
        model = fullpath(model)
        # An existing specification path needs no searching
        # (the suffix matching below cannot match an absolute path)
        if os.path.exists(model):
            return model

    model_subdir = os.path.dirname(model).replace('\\', '/').strip('/')
    model_name, _ = os.path.splitext(os.path.basename(model))
    archive_ext = get_archive_extension(test=test)

    # Suffixes matched against the end of every file path found in the search dirs.
    # The leading '/' forces a whole-name match, e.g. 'my_model.py' must not match 'not_my_model.py'
    path_suffix = f'/{model_subdir}/{model_name}' if model_subdir else f'/{model_name}'
    model_path = f'{path_suffix}.py'
    model_arc_path = f'{path_suffix}{archive_ext}'

    py_path = None
    archive_path = None

    logger.debug(f'Model search path(s): {",".join(search_dirs)}')
    for search_dir in search_dirs:
        if py_path is not None:
            break

        for root, _, files in os.walk(search_dir, followlinks=True):
            root = root.replace('\\', '/')
            for fn in files:
                file_path = f'{root}/{fn}'
                if file_path.endswith(model_path):
                    py_path = file_path
                # Keep the first archive found so search-dir priority is respected
                if archive_path is None and file_path.endswith(model_arc_path):
                    archive_path = file_path

            # If the spec was found then break out of the loop
            if py_path is not None:
                break
            # Do NOT recurse into the CWD
            if search_dir == cwd:
                break

    # Fall back to extracting the specification from a model archive
    if py_path is None and archive_path is not None:
        logger.info(f'Extracting {model_name}.py from {archive_path}')
        py_path = extract_file(
            archive_path=archive_path,
            name=f'{model_name}.py',
            dest_dir=os.path.dirname(archive_path),
        )

    if not py_path and print_not_found_err:
        from mltk.cli import print_did_you_mean_error  # pylint: disable=import-outside-toplevel
        all_models = list_mltk_models(test=test)
        print_did_you_mean_error('Failed to find model', model, all_models, and_exit=True)

    return py_path


def push_active_model(mltk_model: MltkModel):
    """Push the given model onto the module-global active model stack"""
    globals().setdefault('_active_model_stack', []).append(mltk_model)


def pop_active_model() -> MltkModel:
    """Pop and return the most recently pushed active model"""
    _active_model_stack = globals().get('_active_model_stack', [])
    assert len(_active_model_stack) > 0, 'No active model'
    return _active_model_stack.pop()


def get_active_model() -> MltkModel:
    """Return the most recently pushed active model, or None if the stack is empty"""
    _active_model_stack = globals().get('_active_model_stack', [])
    if not _active_model_stack:
        return None
    return _active_model_stack[-1]


def trigger_model_event(event: MltkModelEvent, **kwargs):
    """Trigger the given event on the current active model"""
    active_model = get_active_model()
    assert active_model is not None, 'No active model'
    active_model.trigger_event(event, **kwargs)


def _get_model_search_dirs() -> List[str]:
    """Return list of model search directories

    This populates the list as follows:

    - ~/.mltk/user_settings.yaml:model_paths
    - CWD
    - MLTK_MODEL_PATHS OS environment variable
    - mltk.models package directory
    """
    # Copy: as_list() may hand back the settings' own list object,
    # which must not be mutated by the appends below
    search_dirs = list(as_list(get_user_setting('model_paths')))

    # Include the CWD only if it's not the root of the mltk repo
    curdir = fullpath(os.getcwd())
    if fullpath(MLTK_ROOT_DIR) != curdir:
        search_dirs.append(os.getcwd())

    env_paths = os.getenv('MLTK_MODEL_PATHS', '')
    if env_paths:
        # Skip empty entries, e.g. from a trailing or doubled path separator
        search_dirs.extend(p for p in env_paths.split(os.pathsep) if p)

    search_dirs.append(os.path.dirname(mltk_models.__file__))
    return [fullpath(x) for x in search_dirs]


_Version = collections.namedtuple('_Version', ['major', 'minor', 'patch'])


def _parse_version(version: str) -> _Version:
    """Parse a ``major.minor.patch`` version string into a _Version tuple

    Missing components default to 0. Only the leading digits of each component
    are used, so strings such as ``0.18.0rc1`` or ``1.2.3.dev0`` parse as
    (0, 18, 0) and (1, 2, 3) rather than raising ``ValueError``.
    """
    numbers = []
    for tok in version.split('.')[:3]:
        match = re.match(r'\d+', tok.strip())
        numbers.append(int(match.group(0)) if match else 0)
    numbers += [0] * (3 - len(numbers))
    return _Version(*numbers)
Important: We use cookies only for functional and traffic analytics.
We DO NOT use cookies for any marketing purposes. By using our site, you acknowledge that you have read and understood our Cookie Policy.