fluesterpost/nn_model_manager.py
2023-10-23 17:49:26 +02:00

66 lines
1.9 KiB
Python

import threading
from faster_whisper import WhisperModel
import faster_whisper
from typing import Literal, Iterable, Tuple
# Model size names accepted by set_model(); passed straight to WhisperModel.
ModelSize = Literal[
    "tiny",
    "base",
    "small",
    "medium",
    "large-v1",
    "large-v2",
]

# Target device for inference; passed straight to WhisperModel.
Device = Literal["cuda", "cpu", "auto"]

# Precision presets intended for a future compute-type option.
# NOTE(review): appears unused by set_model's current signature — confirm.
ComputeType = Literal["8bit", "16bit", "32bit"]

# The currently loaded model, or None when nothing is loaded.
_model: WhisperModel | None = None
def set_model(size: ModelSize, device: Device):
    """Load a faster-whisper model and make it the module's active model.

    Args:
        size: Model size name, passed as the first positional argument to
            ``WhisperModel``.
        device: Device to run on (``"cuda"``, ``"cpu"`` or ``"auto"``),
            passed to ``WhisperModel`` as ``device``.

    Raises:
        Whatever ``WhisperModel`` raises while loading; nothing is caught here.

    Note:
        The previously active model (if any) stays referenced until the new
        one has finished loading, so both may be resident briefly. If loading
        raises, the previous model remains active.
    """
    # TODO: expose a ``compute_type: ComputeType`` option. An earlier draft
    # mapped "8bit" -> "int8_float16" (GPU) / "int8" (CPU), "16bit" -> "int8"
    # on GPU (almost certainly meant "float16"), and "32bit" -> "float".
    # Verify these names against CTranslate2's quantization docs before
    # wiring it up.
    global _model
    _model = WhisperModel(size, device=device)
def unload_model():
    """Drop the module's reference to the active model, if there is one.

    Does nothing when no model is loaded.
    """
    global _model
    if is_model_loaded():
        # Only the module-level reference is cleared. Whether the model's
        # memory (e.g. GPU memory) is freed right away depends on no other
        # references surviving — TODO confirm this actually releases it.
        _model = None
def is_model_loaded() -> bool:
    """Return True if set_model() has loaded a model that has not since been unloaded."""
    # A plain read of the module global needs no `global` declaration.
    loaded = _model is not None
    return loaded
def transcribe_from_file(
    mp3_path: str,
    *,
    beam_size: int = 5,
) -> Tuple[Iterable[faster_whisper.transcribe.Segment], faster_whisper.transcribe.TranscriptionInfo] | None:
    """Transcribe an audio file with the currently loaded model.

    Args:
        mp3_path: Path to the audio file, handed unchanged to
            ``WhisperModel.transcribe``. This function does no format check
            of its own.
        beam_size: Beam width used for decoding. Defaults to 5, the value
            this function previously hard-coded.

    Returns:
        ``None`` if no model is loaded; otherwise the raw ``(segments, info)``
        pair returned by faster-whisper.

    Raises:
        Nothing is caught here. Every exception from faster-whisper propagates
        to the caller, which is expected to handle it.
        NOTE(review): faster-whisper documents ``segments`` as a lazy
        generator, so decoding errors may surface only while the caller
        iterates it — handle exceptions around that iteration too.
    """
    # Read the global once: if unload_model() ran between a separate
    # "is loaded?" check and the call, we would end up calling .transcribe
    # on None.
    model = _model
    if model is None:
        return None
    segments, info = model.transcribe(mp3_path, beam_size=beam_size)
    return segments, info