From d721eb3a5b9751c2d4f5e68543cce401ad109f87 Mon Sep 17 00:00:00 2001 From: Yandrik Date: Sun, 29 Oct 2023 20:36:25 +0100 Subject: [PATCH] feat: implemented live recording transcription --- .gitignore | 3 + .idea/.gitignore | 3 + .idea/fluesterpost.iml | 12 + .../inspectionProfiles/profiles_settings.xml | 6 + .idea/misc.xml | 7 + .idea/modules.xml | 8 + .idea/vcs.xml | 6 + __pycache__/nn_model_manager.cpython-311.pyc | Bin 2646 -> 3705 bytes ...isper_webservice_interface.cpython-311.pyc | Bin 0 -> 4057 bytes main.py | 514 ++++++++++++++---- main.spec | 37 ++ nn_model_manager.py | 26 + openapitools.json | 7 + poetry.lock | 178 +++++- pyproject.toml | 8 +- whisper_webservice_interface.py | 84 +++ 16 files changed, 777 insertions(+), 122 deletions(-) create mode 100644 .gitignore create mode 100644 .idea/.gitignore create mode 100644 .idea/fluesterpost.iml create mode 100644 .idea/inspectionProfiles/profiles_settings.xml create mode 100644 .idea/misc.xml create mode 100644 .idea/modules.xml create mode 100644 .idea/vcs.xml create mode 100644 __pycache__/whisper_webservice_interface.cpython-311.pyc create mode 100644 main.spec create mode 100644 openapitools.json create mode 100644 whisper_webservice_interface.py diff --git a/.gitignore b/.gitignore new file mode 100644 index 0000000..77df241 --- /dev/null +++ b/.gitignore @@ -0,0 +1,3 @@ +/build +/dist +/__pycache__ diff --git a/.idea/.gitignore b/.idea/.gitignore new file mode 100644 index 0000000..26d3352 --- /dev/null +++ b/.idea/.gitignore @@ -0,0 +1,3 @@ +# Default ignored files +/shelf/ +/workspace.xml diff --git a/.idea/fluesterpost.iml b/.idea/fluesterpost.iml new file mode 100644 index 0000000..8b8c395 --- /dev/null +++ b/.idea/fluesterpost.iml @@ -0,0 +1,12 @@ + + + + + + + + + + \ No newline at end of file diff --git a/.idea/inspectionProfiles/profiles_settings.xml b/.idea/inspectionProfiles/profiles_settings.xml new file mode 100644 index 0000000..105ce2d --- /dev/null +++ 
b/.idea/inspectionProfiles/profiles_settings.xml @@ -0,0 +1,6 @@ + + + + \ No newline at end of file diff --git a/.idea/misc.xml b/.idea/misc.xml new file mode 100644 index 0000000..b772daa --- /dev/null +++ b/.idea/misc.xml @@ -0,0 +1,7 @@ + + + + + + \ No newline at end of file diff --git a/.idea/modules.xml b/.idea/modules.xml new file mode 100644 index 0000000..606cade --- /dev/null +++ b/.idea/modules.xml @@ -0,0 +1,8 @@ + + + + + + + + \ No newline at end of file diff --git a/.idea/vcs.xml b/.idea/vcs.xml new file mode 100644 index 0000000..35eb1dd --- /dev/null +++ b/.idea/vcs.xml @@ -0,0 +1,6 @@ + + + + + + \ No newline at end of file diff --git a/__pycache__/nn_model_manager.cpython-311.pyc b/__pycache__/nn_model_manager.cpython-311.pyc index 9d02e678da17a9f2974a3a5cc767a2c31ceb884e..4e9c9f0b81d0ef4c1fb4e3e4df6d0cebbf3eff1b 100644 GIT binary patch delta 1184 zcmbVK&2JM&6rb6h{qow0ak2?X904hW7y?cxSc;OSR6-Sp3MxUAYKw^0%s4pm+R5x% zBB`wyDTfGEYNIV`j}`(-MGmR*2k0?2?6*KOaTlq%m8zDh9+#-f+@0D$RJxXCAb?y zCSpD@Bio9p*s7`8nyIlEv;suJa7h7x=on2f7NTIOhtoC03^H4O&kiwrR(a3ShE)>8 z8VRKk4y+)v9qTK=4C5di+c_O%tbTC^ZDnWCd&C+v5Q`@VyFW$BG_Q+T_4wHi) zmfWu&VC+2@UIDSl=22J_Vc5He9>5{*qO=O*URCPrm#jQa=WeiuwgfRvWaHIe-WD2I zX6Y_B-^rZN)C$0D=>W{N)~pK;ng4MsR)hsu=0*P}`|_n`_#xZO4sY_LgJNm$17yPiA*ZZJmF? 
zS$~1PmL!%v8d_lI(gx>47#F!7=urZ!LqxtP%#Fp##qf5yM3v2aTW8# z&hM|TPF7IQ=5uKs4$}cfPCsH06B})59G9Nhb;Nman8PUs-ZSN>G|H4Oytm5bXs<7% zbH3)@q{PDM?5r>SQ#$Rs__O#U>Sy9~EQ2w@#)rFUN)ge%Q|3q&f-zwVvWo_uM*-tyEoVAG@g##eO| xsiFQFI<_Oi#PRx2qBitdZRiRsnYapf034Zu_35u)ig4r_Y$W{`M!X#@`7ezF1}y*p delta 459 zcmY+9JxfC|7=`oR?{BNNRIQ+*6goKA!Nnr_3mim*EOKko+S1!XZWSqrONZi6kPMD4 ze(d%K_!pc+I=Q+zs*~?kiuENq$#YIPkk9$&ba3zc4#DYn*Llyp3}((hkgmucVT3`0 z(BxDa5GhoFUk@mtr#93`qgtd>Ju;|)8H6TR7<5cRxW%**_Hd2qvq}BWP#f78V+YwB zW0zTc-)1&7&-6`uVexg*m-fw2 zBt_YE+RW^}ee>q;&3kXYnf-kz6hM&Pyx)g?!2O*TDyemYS&wlD-9jo-xfDuslU$mg z`FzXVT+GG^t4tqN)V{MmY^(cYVY4t zAJ|eK#C{yW-Xve^lWbNSu6rhfYBSR%1o~_pHHbk=EN<@g{5Vwqq@*^&SPg3UI+_gO zMs?d&o-T$o@Cb+Q0{5W0Hn0VMj=j;AZ>TgXy3%OI(hOZmxt6)nZQ>Pjg^pUx`N*lpBw&C-;VoFPU!W4T_e zCkz$4UL$K|vX&d3H;5`*S{j>{lFpc}mtZSP^n_GZ5g}L;65w4w&%A1*--2I%$LA!Q zw*_0YJ+{~Osr+qk&aVpBk?mjE)VapiuY3^ImverL&ashf&Jrs;LU+M_510+9jJ#=4 zjW&|asbYQo>nInn11p>QSK7izz}Bpv3#uNP>-%#^5AH|Z$Q0-KYr-Wo&vEDy;vnn1 z%|Rcp9n`p7gB7feJL|3Rwq*P5hR=O>VKon`E6RoJ&}>y};H&mcg8!q^;$H$0NawJw zMwAG_?`FXv5zB~JGdR+JZY%;78AH>p7;87ioMf_E3tzLKM}}Y%>lRicQwvPaVkAOi zB%&oFcn*^VYev(jBjEC@& zEYta-5d#z>^O}j-5Q}15F-g=*{9pz`@qtj|kOoRZlr#cKK=@P>bx#Yt>$`#6LK?Es z<-pIt$vG~^+x#DWHfQs{;co=bq5Is(eRN-N`B<0BU3IxRliC%`_3UJP#z^D%f}-o% z)i@ZzFyqM-wH+qFTk-iB&CFmT&*Lc*lQ}JcWiTQpNhN_}nT2P8UKRFM0`^AV;sLOq zD#JSD4g1G>Q$|8bnSHTZD+(Fwdw}GT6FOObKC*7UyS(>d)5Bd$%1xj&uR{~$@?>hn z4b*n)ecttE6r!X}*F!yHM*XfwvoMqcnyKkvaXo=u!L*1g8W{|GpI}9Gr45Z>gVSZX z!A-7#Jx@K-NN9Kw8dop(f-6;GiO_YpT$1dfog%aV6P4X;Qip+wHzm?eTWs5EC2|zR zH2lp*@NXWKwxO1`C7b5_aLJ3BT8rCXE$%o}Y;LE|&cmgUuhCmVK=P+c&8YD$4$Qbb z_nG~v{mFZ`--E(h=wu;uGJpDOUz5{(_zweXzOI6=tAs>P?5hwJ-5XpBbr(Y2CB*sn zJ^X30W5;K0pSFF{e!G3?MJ_V);n zdc^)We2-3WK;3XHBp84vsz9+}OIlu^sE1oYg?$0QBQlT`SD?xBsUF~-16i>iWM2TC z9^120h1$*I*N>d@{(sO3py>tOeVd_swhGbg81X#^nvyNedMY}o(#A2m-Ii9Wa9gRR z-+B$tq0PO4ufi{bT)*N2^S%Ui2eE|{R^;r)QTksH7;6*=0kbKN&{LZ%)Pd+PI#o6f zR^`f}+<79>*?YEsez{dV(>dq@>#LOlV*ff*b 
z)RBlmr~j_-T|Q41x5dk99<8Xup!>T5HlblpD7usFKq_#}0EAEkUEE)uu4lb&Q%@GC=Kc#G1c2ink4%zpy`8J>E>&I84T*A zmqMLCv}7z>YyJZT{{a9Z&+dC%v9;~i_^-!5n)rAE07w8Ct6&6R&9fU|ZL_RaTjR>G z_IxYpe}Xe#xZvBlY<}t6=lJ#&!>{}*_siUe@85Xek#;``zfuVAUzxab*$E$Wq+@0i z*kHA_ze9M`A@=X`J&JHZ6CkW6x@j4(-2?{y4!9iLd~=)y`>>L11ot^X6+idf1#oy! z1&AD92VDGJ=<}f3k3&`!CJ2kKJmk@ zuCB#Xagx>JX{>5W<*IRJATl;IFgi9kG!_R&K7HowQ2Y$NS+Yln1HU`MFy$q=}ku%H4FJIz_JvT&q!gK#VdqL@}R zo%{eq!|zE}uf%nyGh`j-xUFnWEI zw%EOFu8gjQj~Bwn>uT~)%l4Zq$ugZrmccSCv0WVvyIyM{1NTTmRU*_wgc_ev@0PdU ze&Afs#t+GqVZe_JgC|#@>!a5Nwn(-BI*qQRu`HMIf^Btr(Re2HrVTZl!hPf|C{kaW z^#23)0ml{5QRmy9Cn%g}8IJB8*UuNvPZTd)DtY#D@e%^Eyq|qmTUN7w+R50*!_I|? zzx{;qPReq!*9xfxRw^DmdM|P3a!G)G#ep9`9{i$@@lSRgTuI!%%m#fjc=qx6FGkAz zm4u^CJ6g(7Q-v$(+N5l>G8yJ}a6KgiWVwZX?wRbfI{er3e;#41Gd}6ObH(|2!g(iA z7+0O?l#|XlBB& K#2p}J4*w75#@gEe literal 0 HcmV?d00001 diff --git a/main.py b/main.py index 073c5bf..2d2c024 100644 --- a/main.py +++ b/main.py @@ -1,6 +1,10 @@ import os import pprint import traceback +import typing + +import requests.exceptions +import validators import utils import flet as ft @@ -10,30 +14,154 @@ from typing import DefaultDict import pygame import nn_model_manager as mm +import whisper_webservice_interface + +import wave +import sys +import pyaudio + + +# === TEMP === +import logging + +logging.basicConfig() +logging.getLogger("faster_whisper").setLevel(logging.DEBUG) + +# === END === + + +# globals +transcribe_ready: bool = False +recording: bool = False +rec_stream: pyaudio.Stream | None = None +sound_chunks = [] +recorded_audio = [] + +# AUDIO stuff +REC_CHUNK = 1024 +REC_FORMAT = pyaudio.paInt16 +REC_CHANNELS = 1 +REC_RATE = 16000 +REC_RECORD_SECONDS = 5 def main(page): pygame.mixer.init() - first_name = ft.Ref[ft.TextField]() - last_name = ft.Ref[ft.TextField]() - greetings = ft.Ref[ft.Column]() + # get audio device names + p = pyaudio.PyAudio() + + capture_devices = [(i, p.get_device_info_by_index(i)['name']) for i in range(p.get_device_count()) if + 
p.get_device_info_by_index(i)['maxInputChannels'] > 0] + + record_button = ft.Ref[ft.IconButton]() + mic_select = ft.Ref[ft.Dropdown]() file_tree = ft.Ref[ft.Column]() file_tree_empty_text = ft.Ref[ft.Text]() - - load_model_text = ft.Ref[ft.Text]() + + # mode select + current_mode_select = ft.Ref[ft.Dropdown]() + current_mode_info_text = ft.Ref[ft.Text]() + processing_spinner = ft.Ref[ft.ProgressRing]() + + # local model mode model_size_select = ft.Ref[ft.Dropdown]() model_device_select = ft.Ref[ft.Dropdown]() # model_bits_select = ft.Ref[ft.Dropdown]() model_load_unload_button = ft.Ref[ft.IconButton]() - model_loading_spinner = ft.Ref[ft.ProgressRing]() - + + # docker whisper webservice mode + whisper_webservice_url_input = ft.Ref[ft.TextField]() + transcribe_buttons: list[ft.Ref[ft.IconButton]] = [] - + output_text_container = ft.Ref[ft.Container]() output_text_col = ft.Ref[ft.Column]() + def transcribe(fileOrBytes: str | bytes): + print(f"DEBUG: trying to transcribe audio {fileOrBytes if isinstance(fileOrBytes, str) else f'with len {len(fileOrBytes)}'}") + + # === LOCAL MODEL CODE === + if current_mode_select.current.value == 'local': + if not mm.is_model_loaded() or (isinstance(fileOrBytes, str) and not fileOrBytes.endswith('.mp3')): + print("DEBUG: can't transcribe a non-MP3 file or while no model is loaded") + return + + print(f"DEBUG: starting transcription") + output_text_container.current.alignment = ft.alignment.center + output_text_col.current.controls = [ft.ProgressRing()] + + # set all transcribe buttons to disabled + for btn in transcribe_buttons: + btn.current.disabled = True + page.update() + + try: + if isinstance(fileOrBytes, str): + segments, info = mm.transcribe_from_file(fileOrBytes) + else: + segments, info = mm.transcribe_from_i16_audio(fileOrBytes) + + txt = '' + + for seg in segments: + txt += seg.text + '\n' + + output_text_container.current.alignment = ft.alignment.top_left + output_text_col.current.controls = [ft.Text(txt, 
selectable=True)] # TODO + + except Exception as e: + output_text_container.current.alignment = ft.alignment.center + output_text_col.current.controls = [ft.Text(f"Transcribing failed: {str(e)}")] # TODO + + finally: + # set all transcribe buttons to disabled + for btn in transcribe_buttons: + btn.current.disabled = False + page.update() + + # === WEBSERVICE MODE CODE === + elif current_mode_select.current.value == 'webservice': + url = whisper_webservice_url_input.current.value + print(f"DEBUG: starting web transcription") + if validators.url(url, simple_host=True): + + output_text_container.current.alignment = ft.alignment.center + output_text_col.current.controls = [ft.ProgressRing()] + # set all transcribe buttons to disabled + for btn in transcribe_buttons: + btn.current.disabled = True + page.update() + + try: + print(f'DEBUG: sending web request...') + code, text = whisper_webservice_interface.send_asr_request(url, fileOrBytes, task="transcribe") + except requests.exceptions.RequestException as e: + output_text_container.current.alignment = ft.alignment.center + print(f'web transcription failed: {str(e)}') + output_text_col.current.controls = \ + [ft.Text(f"HTTP Request to {url}/asr failed. 
Reason:\n{str(e)}")] + # set all transcribe buttons to enabled + for btn in transcribe_buttons: + btn.current.disabled = False + page.update() + return + + # set all transcribe buttons to enabled + for btn in transcribe_buttons: + btn.current.disabled = False + + if code == 200: + output_text_container.current.alignment = ft.alignment.top_left + output_text_col.current.controls = [ft.Text(text, selectable=True)] + else: + output_text_container.current.alignment = ft.alignment.center + output_text_col.current.controls = \ + [ft.Text(f"HTTP Request to {url}/asr failed ({code}):\n{text}")] + + page.update() + def generate_file_tree(path: str, tree_dict: dict | DefaultDict): if path[-1] == os.sep: path = path[:-1] @@ -67,7 +195,7 @@ def main(page): for file in tree_dict[utils.FILES_KEY]: control = [ft.Text(file)] - + if not file.endswith('.mp3'): continue @@ -99,49 +227,19 @@ def main(page): _button_ref = ft.Ref[ft.IconButton]() control.append(ft.IconButton(icon=ft.icons.PLAY_CIRCLE_OUTLINED, ref=_button_ref, - on_click=lambda _, f=full_file_path, r=_button_ref: start_playing(f, r))) - - def transcribe(filepath: str): - print(f"DEBUG: trying to transcribe file {filepath}") - if not mm.is_model_loaded() or not filepath.endswith('.mp3'): - return - - print(f"DEBUG: starting transcription") - output_text_container.current.alignment = ft.alignment.center - output_text_col.current.controls = [ft.ProgressRing()] - - # set all transcribe buttons to disabled - for btn in transcribe_buttons: - btn.current.disabled = True - page.update() - - try: - segments, info = mm.transcribe_from_file(filepath) - - txt = '' - - for seg in segments: - txt += seg.text + '\n' - - output_text_container.current.alignment = ft.alignment.top_left - output_text_col.current.controls = [ft.Text(txt, selectable=True)] # TODO - - except Exception as e: - output_text_container.current.alignment = ft.alignment.center - output_text_col.current.controls = [ft.Text(f"Transcribing failed: {str(e)}")] # TODO - 
- finally: - # set all transcribe buttons to disabled - for btn in transcribe_buttons: - btn.current.disabled = False - page.update() - - + on_click=lambda _, f=full_file_path, r=_button_ref: start_playing(f, r))) + transcribe_button_ref = ft.Ref[ft.IconButton]() - - control.append(ft.IconButton(icon=ft.icons.FORMAT_ALIGN_LEFT, disabled=not mm.is_model_loaded(), ref=transcribe_button_ref, + + # check enabled + enabled = (current_mode_select.current.value == 'local' and mm.is_model_loaded()) or ( + current_mode_select.current.value == 'webservice' and + validators.url(whisper_webservice_url_input.current.value, simple_host=True)) + + control.append(ft.IconButton(icon=ft.icons.FORMAT_ALIGN_LEFT, disabled=not enabled, + ref=transcribe_button_ref, on_click=lambda _, f=full_file_path: transcribe(f))) - + transcribe_buttons.append(transcribe_button_ref) files_controls.append(ft.Row(control)) @@ -155,15 +253,6 @@ def main(page): ] ) - def btn_click(e): - greetings.current.controls.append( - ft.Text(f"Hello, {first_name.current.value} {last_name.current.value}!") - ) - first_name.current.value = "" - last_name.current.value = "" - page.update() - first_name.current.focus() - def on_dialog_result(e: ft.FilePickerResultEvent): path = e.path if path: @@ -182,18 +271,58 @@ def main(page): page.update() except e: print("didn't work aaa") # TODO: fix - - def load_model(): - - load_model_text.current.value = 'Loading... This may take a while.' 
- - model_size_select.current.disabled = True - model_device_select.current.disabled = True - # model_bits_select.current.disabled = True - model_load_unload_button.current.disabled = True - model_loading_spinner.current.visible = True + + def mode_select(): + global transcribe_ready + if mm.is_model_loaded(): + print("BUG: cannot change mode while model is loaded!") + return + + next_mode = current_mode_select.current.value + if next_mode == 'local': + # enable model selects & loads + model_size_select.current.visible = True + model_device_select.current.visible = True + model_load_unload_button.current.visible = True + model_size_select.current.disabled = False + model_device_select.current.disabled = False + + whisper_webservice_url_input.current.visible = False + + for btn in transcribe_buttons: + btn.current.disabled = True + + set_transcribe_ready(False) + + elif next_mode == 'webservice': + # enable model selects & loads + model_size_select.current.visible = False + model_device_select.current.visible = False + model_load_unload_button.current.visible = False + model_size_select.current.disabled = True + model_device_select.current.disabled = True + model_load_unload_button.current.disabled = True + current_mode_info_text.current.value = 'Input the URL of the onerahmet/openai-whisper-asr-webservice docker container' + + whisper_webservice_url_input.current.visible = True + whisper_webservice_url_input.current.disabled = False + + on_url_input(None) + + + else: + raise Exception(f'BUG: Impossible mode {next_mode} received!') + page.update() - + page.client_storage.set('selected_mode', next_mode) + + def load_model(): + current_mode_info_text.current.value = 'Loading... This may take a while.' + + page.update() + + paralyze_ui() + try: mm.set_model( size=model_size_select.current.value or 'base', @@ -203,55 +332,148 @@ def main(page): except Exception as e: print(f"loading model failed. 
Exception: {str(e)}") print(traceback.format_exc()) - load_model_text.current.value = f'Loading failed. Reason:\n{str(e)}' - model_size_select.current.disabled = False - model_device_select.current.disabled = False - # model_bits_select.current.disabled = False - + current_mode_info_text.current.value = f'Loading failed. Reason:\n{str(e)}' + set_transcribe_ready(False) + # raise e - - model_loading_spinner.current.visible = False - model_load_unload_button.current.disabled = False - + + processing_spinner.current.visible = False + if mm.is_model_loaded(): - load_model_text.current.value = f'Loaded.' - model_load_unload_button.current.icon = ft.icons.CLOSE - model_load_unload_button.current.on_click = lambda _: unload_model() - + current_mode_info_text.current.value = f'Loaded.' + # if successful, save to shared preferences page.client_storage.set('model_size', model_size_select.current.value) page.client_storage.set('device_select', model_device_select.current.value) - + # set all transcribe buttons to enabled - for btn in transcribe_buttons: - btn.current.disabled = False - - page.update() - + set_transcribe_ready(True) + else: + set_transcribe_ready(False) + def unload_model(): - model_load_unload_button.current.disabled = True - # set all transcribe buttons to disabled - for btn in transcribe_buttons: - btn.current.disabled = True - - page.update() - + paralyze_ui() + if mm.is_model_loaded(): mm.unload_model() - - load_model_text.current.value = 'Select parameters, and then load transcription model.' 
- model_size_select.current.disabled = False - model_device_select.current.disabled = False - # model_bits_select.current.disabled = False + + set_transcribe_ready(False) + + def paralyze_ui(): + model_size_select.current.disabled = True + model_device_select.current.disabled = True + # model_bits_select.current.disabled = True + model_load_unload_button.current.disabled = True + processing_spinner.current.visible = True + current_mode_select.current.disabled = True + + model_load_unload_button.current.icon = ft.icons.CLOSE model_load_unload_button.current.disabled = False - model_load_unload_button.current.icon = ft.icons.START - model_load_unload_button.current.on_click = lambda _: load_model() - model_loading_spinner.current.visible = False + for btn in transcribe_buttons: + btn.current.disabled = True + model_load_unload_button.current.disabled = True page.update() - - - + + def set_transcribe_ready(rdy: bool): + global transcribe_ready + transcribe_ready = rdy + + if transcribe_ready: + for btn in transcribe_buttons: + btn.current.disabled = False + model_size_select.current.disabled = True + model_device_select.current.disabled = True + # model_bits_select.current.disabled = True + model_load_unload_button.current.disabled = True + processing_spinner.current.visible = False + model_load_unload_button.current.on_click = lambda _: unload_model() + + model_load_unload_button.current.icon = ft.icons.CLOSE + model_load_unload_button.current.disabled = False + + if mm.is_model_loaded(): + current_mode_select.current.disabled = True + else: + for btn in transcribe_buttons: + btn.current.disabled = True + model_size_select.current.disabled = False + model_device_select.current.disabled = False + # model_bits_select.current.disabled = False + model_load_unload_button.current.disabled = False + model_load_unload_button.current.icon = ft.icons.START + model_load_unload_button.current.on_click = lambda _: load_model() + processing_spinner.current.visible = False + 
current_mode_select.current.disabled = False + + page.update() + + def on_url_input(e): + url_value = whisper_webservice_url_input.current.value + # print(url_value) + + if validators.url(url_value, simple_host=True): + # print('valid') + page.client_storage.set('webservice_url', url_value) + # set all transcribe buttons to enabled + set_transcribe_ready(True) + else: + # print('invalid') + # set all transcribe buttons to disabled + set_transcribe_ready(False) + + page.update() + + print(tuple(page.client_storage.get('selected_mic'))) + + def toggle_recording(): + global recording + global rec_stream + global sound_chunks + global recorded_audio + + if recording: + print("Stopping recording...") + + rec_stream.stop_stream() + + while not rec_stream.is_stopped(): + pass # wait until stopped + + recorded_audio = b"".join(sound_chunks) + + transcribe(recorded_audio) + + recording = False + + # sound = pygame.mixer.Sound(buffer=recorded_audio) # doesn't work because sampling rate is wrong + + print("playing back recorded sound") + # sound.play() + else: + print("Starting Recording...") + recording = True + + sound_chunks = [] + + def cb(in_data, _frame_count, _time_info, _status): + sound_chunks.append(in_data) + print(_time_info) + return in_data, pyaudio.paContinue + + rec_stream = p.open( + format=REC_FORMAT, + channels=REC_CHANNELS, + rate=REC_RATE, + input=True, + frames_per_buffer=REC_CHUNK, + stream_callback=cb + ) + + rec_stream.start_stream() + + def find_recordingdevice_tuple_by_name(search_name: str) -> typing.Tuple[int, str] | None: + return next(((device_id, name) for device_id, name in capture_devices if name == search_name)) # set up file picker file_picker = ft.FilePicker(on_result=on_dialog_result) @@ -263,11 +485,29 @@ def main(page): ft.Divider() ) + mode = page.client_storage.get('selected_mode') if page.client_storage.contains_key('selected_mode') else 'local' + page.add( ft.ResponsiveRow([ ft.Container( ft.Column([ - ft.ElevatedButton("Add 
Folder", on_click=lambda _: file_picker.get_directory_path()), + ft.Row([ + ft.ElevatedButton("Add Folder", on_click=lambda _: file_picker.get_directory_path()), + ft.Container(expand=True), + ft.IconButton(ft.icons.RECORD_VOICE_OVER, ref=record_button, + on_click=lambda _: toggle_recording()), + ]), + ft.Dropdown( + ref=mic_select, + options=[ft.dropdown.Option(x[1]) for x in capture_devices], + value=page.client_storage.get('selected_mic')[1] if ( + page.client_storage.contains_key('selected_mic') and tuple( + page.client_storage.get('selected_mic')) in capture_devices) else capture_devices[0][1], + height=36, + content_padding=2, + on_change=lambda _: page.client_storage.set('selected_mic', find_recordingdevice_tuple_by_name( + mic_select.current.value)) if mic_select.current.value else None + ), ft.Column(ref=file_tree, scroll=ft.ScrollMode.ALWAYS, expand=True), # ft.ListView(ref=file_tree), ft.Text("No Folder Open Yet", style=ft.TextTheme.body_small, color="grey", @@ -275,21 +515,44 @@ def main(page): ], expand=True), expand=True, col=4), ft.Container(expand=True, content=ft.Column(expand=True, controls=[ ft.Column([ - ft.Text('Select parameters, and then load transcription model.', ref=load_model_text), + ft.Text( + 'Select parameters, and then load transcription model.' + if mode == 'local' + else 'Input the URL of the onerahmet/openai-whisper-asr-webservice docker container' + , ref=current_mode_info_text), ft.Row([ + ft.Dropdown( + ref=current_mode_select, + width=160, + hint_text='mode', + value=mode, + on_change=lambda _: mode_select(), + options=[ + ft.dropdown.Option('local'), + ft.dropdown.Option('webservice'), + ], + ), + + # === LOCAL MODE === ft.Dropdown( ref=model_size_select, width=100, hint_text='model size', - value=page.client_storage.get('model_size') if page.client_storage.contains_key('model_size') else 'base', - options=[ft.dropdown.Option(x) for x in mm.ModelSize.__args__], # __args__ is not perfect here. But works. 
+ value=page.client_storage.get('model_size') if page.client_storage.contains_key( + 'model_size') else 'base', + options=[ft.dropdown.Option(x) for x in mm.ModelSize.__args__], + # __args__ is not perfect here. But works. + visible=mode == 'local', ), ft.Dropdown( ref=model_device_select, width=100, hint_text='device', - value=page.client_storage.get('device_select') if page.client_storage.contains_key('device_select') else 'auto', - options=[ft.dropdown.Option(x) for x in mm.Device.__args__] # __args__ is not perfect here. But works. + value=page.client_storage.get('device_select') if page.client_storage.contains_key( + 'device_select') else 'auto', + options=[ft.dropdown.Option(x) for x in mm.Device.__args__], + visible=mode == 'local', + # __args__ is not perfect here. But works. ), # ft.Dropdown( # ref=model_bits_select, @@ -297,23 +560,36 @@ def main(page): # hint_text='bits', # value='16bit', # options=[ft.dropdown.Option(x) for x in mm.ComputeType.__args__] # __args__ is not perfect here. But works. - #), + # ), ft.IconButton( icon=ft.icons.START, ref=model_load_unload_button, on_click=lambda _: load_model(), + visible=mode == 'local', ), - ft.ProgressRing(ref=model_loading_spinner, visible=False) + # === WEBSERVICE MODE === + ft.TextField( + ref=whisper_webservice_url_input, + visible=mode == 'webservice', + on_change=on_url_input, + hint_text='e.g. 
http://localhost:9000', + value=page.client_storage.get('webservice_url') if page.client_storage.contains_key( + 'webservice_url') else '', + ), + # TODO: question mark hint button about what the web service is + + # === GENERAL === + ft.ProgressRing(ref=processing_spinner, visible=False) ]) ]), - ft.Container(expand=True, padding=12, border=ft.border.all(2, 'grey'), + ft.Container(expand=True, padding=12, border=ft.border.all(2, 'grey'), alignment=ft.alignment.center, ref=output_text_container, content=ft.Column( [ft.Text('Nothing to see here!', text_align=ft.TextAlign.CENTER)], - ref=output_text_col, - expand=True, - scroll=ft.ScrollMode.ADAPTIVE)), + ref=output_text_col, + expand=True, + scroll=ft.ScrollMode.ADAPTIVE)), ]), col=8) ], expand=True), ) diff --git a/main.spec b/main.spec new file mode 100644 index 0000000..f69f861 --- /dev/null +++ b/main.spec @@ -0,0 +1,37 @@ +# -*- mode: python ; coding: utf-8 -*- + + +a = Analysis( + ['main.py'], + pathex=[], + binaries=[], + datas=[], + hiddenimports=[], + hookspath=[], + hooksconfig={}, + runtime_hooks=[], + excludes=[], + noarchive=False, +) +pyz = PYZ(a.pure) + +exe = EXE( + pyz, + a.scripts, + a.binaries, + a.datas, + [], + name='main', + debug=False, + bootloader_ignore_signals=False, + strip=False, + upx=True, + upx_exclude=[], + runtime_tmpdir=None, + console=False, + disable_windowed_traceback=False, + argv_emulation=False, + target_arch=None, + codesign_identity=None, + entitlements_file=None, +) diff --git a/nn_model_manager.py b/nn_model_manager.py index 68a2716..6c01787 100644 --- a/nn_model_manager.py +++ b/nn_model_manager.py @@ -1,5 +1,7 @@ +import io import threading +import numpy as np from faster_whisper import WhisperModel import faster_whisper from typing import Literal, Iterable, Tuple @@ -47,6 +49,30 @@ def is_model_loaded() -> bool: return _model is not None +def transcribe_from_i16_audio(audio: bytes) -> Tuple[Iterable[ + faster_whisper.transcribe.Segment], 
faster_whisper.transcribe.TranscriptionInfo] | None: + """ + Transcribe audio from an MP3 file. + Note that this can - and will - crash if you don't catch exceptions. + + If the model isn't loaded yet, this will return None. + Otherwise, it will return the raw transcription from `faster-whisper`. + """ + if not is_model_loaded(): + return None + + data = np.frombuffer(audio, dtype=np.int16) + + # Convert s16 to f32. + data = data.astype(np.float32) / 32768.0 + + global _model + segments, info = _model.transcribe(data, beam_size=5) + # transcribe, and throw all exceptions to application to handle + + return segments, info + + def transcribe_from_file(mp3_path: str) -> Tuple[Iterable[faster_whisper.transcribe.Segment], faster_whisper.transcribe.TranscriptionInfo] | None: """ Transcribe audio from an MP3 file. diff --git a/openapitools.json b/openapitools.json new file mode 100644 index 0000000..4053ae8 --- /dev/null +++ b/openapitools.json @@ -0,0 +1,7 @@ +{ + "$schema": "./node_modules/@openapitools/openapi-generator-cli/config.schema.json", + "spaces": 2, + "generator-cli": { + "version": "7.0.1" + } +} diff --git a/poetry.lock b/poetry.lock index 83e0305..3cf8ed7 100644 --- a/poetry.lock +++ b/poetry.lock @@ -1,5 +1,17 @@ # This file is automatically @generated by Poetry and should not be changed by hand. 
+[[package]] +name = "altgraph" +version = "0.17.4" +description = "Python graph (network) package" +category = "main" +optional = false +python-versions = "*" +files = [ + {file = "altgraph-0.17.4-py2.py3-none-any.whl", hash = "sha256:642743b4750de17e655e6711601b077bc6598dbfa3ba5fa2b2a35ce12b508dff"}, + {file = "altgraph-0.17.4.tar.gz", hash = "sha256:1b5afbb98f6c4dcadb2e2ae6ab9fa994bbb8c1d75f4fa96d340f9437ae454406"}, +] + [[package]] name = "annotated-types" version = "0.6.0" @@ -681,6 +693,21 @@ files = [ {file = "lit-17.0.3.tar.gz", hash = "sha256:e6049032462be1e2928686cbd4a6cc5b3c545d83ecd078737fe79412c1f3fcc1"}, ] +[[package]] +name = "macholib" +version = "1.16.3" +description = "Mach-O header analysis and editing" +category = "main" +optional = false +python-versions = "*" +files = [ + {file = "macholib-1.16.3-py2.py3-none-any.whl", hash = "sha256:0e315d7583d38b8c77e815b1ecbdbf504a8258d8b3e17b61165c6feb60d18f2c"}, + {file = "macholib-1.16.3.tar.gz", hash = "sha256:07ae9e15e8e4cd9a788013d81f5908b3609aa76f9b1421bae9c4d7606ec86a30"}, +] + +[package.dependencies] +altgraph = ">=0.17" + [[package]] name = "markupsafe" version = "2.1.3" @@ -1075,6 +1102,18 @@ files = [ {file = "pathspec-0.11.2.tar.gz", hash = "sha256:e0d8d0ac2f12da61956eb2306b69f9469b42f4deb0f3cb6ed47b9cce9996ced3"}, ] +[[package]] +name = "pefile" +version = "2023.2.7" +description = "Python PE parsing module" +category = "main" +optional = false +python-versions = ">=3.6.0" +files = [ + {file = "pefile-2023.2.7-py3-none-any.whl", hash = "sha256:da185cd2af68c08a6cd4481f7325ed600a88f6a813bad9dea07ab3ef73d8d8d6"}, + {file = "pefile-2023.2.7.tar.gz", hash = "sha256:82e6114004b3d6911c77c3953e3838654b04511b8b66e8583db70c65998017dc"}, +] + [[package]] name = "plumbum" version = "1.8.2" @@ -1133,6 +1172,30 @@ files = [ {file = "protobuf-4.24.4.tar.gz", hash = "sha256:5a70731910cd9104762161719c3d883c960151eea077134458503723b60e3667"}, ] +[[package]] +name = "pyaudio" +version = "0.2.13" +description = 
"Cross-platform audio I/O with PortAudio" +category = "main" +optional = false +python-versions = "*" +files = [ + {file = "PyAudio-0.2.13-cp310-cp310-win32.whl", hash = "sha256:48e29537ea22ae2ae323eebe297bfb2683831cee4f20d96964e131f65ab2161d"}, + {file = "PyAudio-0.2.13-cp310-cp310-win_amd64.whl", hash = "sha256:87137cfd0ef8608a2a383be3f6996f59505e322dab9d16531f14cf542fa294f1"}, + {file = "PyAudio-0.2.13-cp311-cp311-win32.whl", hash = "sha256:13915faaa780e6bbbb6d745ef0e761674fd461b1b1b3f9c1f57042a534bfc0c3"}, + {file = "PyAudio-0.2.13-cp311-cp311-win_amd64.whl", hash = "sha256:59cc3cc5211b729c7854e3989058a145872cc58b1a7b46c6d4d88448a343d890"}, + {file = "PyAudio-0.2.13-cp37-cp37m-win32.whl", hash = "sha256:d294e3f85b2238649b1ff49ce3412459a8a312569975a89d14646536362d7576"}, + {file = "PyAudio-0.2.13-cp37-cp37m-win_amd64.whl", hash = "sha256:ff7f5e44ef51fe61da1e09c6f632f0b5808198edd61b363855cc7dd03bf4a8ac"}, + {file = "PyAudio-0.2.13-cp38-cp38-win32.whl", hash = "sha256:c6b302b048c054b7463936d8ba884b73877dc47012f3c94665dba92dd658ae04"}, + {file = "PyAudio-0.2.13-cp38-cp38-win_amd64.whl", hash = "sha256:1505d766ee718df6f5a18b73ac42307ba1cb4d2c0397873159254a34f67515d6"}, + {file = "PyAudio-0.2.13-cp39-cp39-win32.whl", hash = "sha256:eb128e4a6ea9b98d9a31f33c44978885af27dbe8ae53d665f8790cbfe045517e"}, + {file = "PyAudio-0.2.13-cp39-cp39-win_amd64.whl", hash = "sha256:910ef09225cce227adbba92622d4a3e3c8375117f7dd64039f287d9ffc0e02a1"}, + {file = "PyAudio-0.2.13.tar.gz", hash = "sha256:26bccc81e4243d1c0ff5487e6b481de6329fcd65c79365c267cef38f363a2b56"}, +] + +[package.extras] +test = ["numpy"] + [[package]] name = "pydantic" version = "2.4.2" @@ -1354,6 +1417,52 @@ files = [ [package.extras] plugins = ["importlib-metadata"] +[[package]] +name = "pyinstaller" +version = "6.1.0" +description = "PyInstaller bundles a Python application and all its dependencies into a single package." 
+category = "main" +optional = false +python-versions = "<3.13,>=3.8" +files = [ + {file = "pyinstaller-6.1.0-py3-none-macosx_10_13_universal2.whl", hash = "sha256:da78942d31c1911ea4abcd3ca3bd0c062af7f163a5e227fd18a359b61deda4ca"}, + {file = "pyinstaller-6.1.0-py3-none-manylinux2014_aarch64.whl", hash = "sha256:f63d2353537bac7bfeeaedbe5ac99f3be35daa290dd1ad1be90768acbf77e3d5"}, + {file = "pyinstaller-6.1.0-py3-none-manylinux2014_i686.whl", hash = "sha256:6e71d9f6f5a1e0f7523e8ebee1b76bb29538f64d863e3711c2b21033f499e2b9"}, + {file = "pyinstaller-6.1.0-py3-none-manylinux2014_ppc64le.whl", hash = "sha256:bebf6f442bbe6343acaec873803510ee1930d026846a018f727da4e0690081f8"}, + {file = "pyinstaller-6.1.0-py3-none-manylinux2014_s390x.whl", hash = "sha256:3c04963637481a3edf1eec64ab4c3fce098908f02fc472c11e73be7eedc08b95"}, + {file = "pyinstaller-6.1.0-py3-none-manylinux2014_x86_64.whl", hash = "sha256:4368e4eb9999ce32e3280330b3c26f175e0fa7fa13efb4d2dc4ade488ff6d7c2"}, + {file = "pyinstaller-6.1.0-py3-none-musllinux_1_1_aarch64.whl", hash = "sha256:041ab9311d08162356829bf47293a613c44dc9ace28846fb63098889c7383c5d"}, + {file = "pyinstaller-6.1.0-py3-none-musllinux_1_1_x86_64.whl", hash = "sha256:331f050e8f9e923bb6b50454acfc0547fd52092585c61eb5f2fc93de60703f13"}, + {file = "pyinstaller-6.1.0-py3-none-win32.whl", hash = "sha256:9e8b5bbc1bdf554ade1360e62e4959091430c3cc15ebfff3c28c8894fd1f312a"}, + {file = "pyinstaller-6.1.0-py3-none-win_amd64.whl", hash = "sha256:f9f5bcaef6122d93c54ee7a9ecb07eab5b81a7ebfb5cb99af2b2a6ff49eff62f"}, + {file = "pyinstaller-6.1.0-py3-none-win_arm64.whl", hash = "sha256:dd438afd2abb643f5399c0cb254a11c217c06782cb274a2911dd785f9f67fa9e"}, + {file = "pyinstaller-6.1.0.tar.gz", hash = "sha256:8f3d49c60f3344bf3d4a6d4258bda665dad185ab2b097341d3af2a6387c838ef"}, +] + +[package.dependencies] +altgraph = "*" +macholib = {version = ">=1.8", markers = "sys_platform == \"darwin\""} +packaging = ">=20.0" +pefile = {version = ">=2022.5.30", markers = "sys_platform == 
\"win32\""} +pyinstaller-hooks-contrib = ">=2021.4" +pywin32-ctypes = {version = ">=0.2.1", markers = "sys_platform == \"win32\""} +setuptools = ">=42.0.0" + +[package.extras] +hook-testing = ["execnet (>=1.5.0)", "psutil", "pytest (>=2.7.3)"] + +[[package]] +name = "pyinstaller-hooks-contrib" +version = "2023.10" +description = "Community maintained hooks for PyInstaller" +category = "main" +optional = false +python-versions = ">=3.7" +files = [ + {file = "pyinstaller-hooks-contrib-2023.10.tar.gz", hash = "sha256:4b4a998036abb713774cb26534ca06b7e6e09e4c628196017a10deb11a48747f"}, + {file = "pyinstaller_hooks_contrib-2023.10-py2.py3-none-any.whl", hash = "sha256:6dc1786a8f452941245d5bb85893e2a33632ebdcbc4c23eea41f2ee08281b0c0"}, +] + [[package]] name = "pypng" version = "0.20220715.0" @@ -1378,6 +1487,36 @@ files = [ {file = "pyreadline3-3.4.1.tar.gz", hash = "sha256:6f3d1f7b8a31ba32b73917cefc1f28cc660562f39aea8646d30bd6eff21f7bae"}, ] +[[package]] +name = "pysdl2" +version = "0.9.16" +description = "Python SDL2 bindings" +category = "main" +optional = false +python-versions = "*" +files = [ + {file = "PySDL2-0.9.16.tar.gz", hash = "sha256:1027406badbecdd30fe56e800a5a76ad7d7271a3aec0b7acf780ee26a00f2d40"}, +] + +[[package]] +name = "pysdl2-dll" +version = "2.28.4" +description = "Pre-built SDL2 binaries for PySDL2" +category = "main" +optional = false +python-versions = "*" +files = [ + {file = "pysdl2-dll-2.28.4.tar.gz", hash = "sha256:051e411ef93778d924a21c6e8fcfabd404bae5620fa49fa417e05b494a6a7dca"}, + {file = "pysdl2_dll-2.28.4-py2.py3-none-macosx_10_11_universal2.whl", hash = "sha256:1acff652e62f906109a6ca4874ff1e210eebb4989df651955c48add43f89c077"}, + {file = "pysdl2_dll-2.28.4-py2.py3-none-macosx_10_11_x86_64.whl", hash = "sha256:a35ab0f06b9e42ba12575b6960ad7ea013fc0f49e6935b4b53d66a0a06668eae"}, + {file = "pysdl2_dll-2.28.4-py2.py3-none-manylinux2014_i686.whl", hash = "sha256:6868f67b831053730c1d429076594e3b4db8522b779c51932b0ca003ae47b134"}, + {file = 
"pysdl2_dll-2.28.4-py2.py3-none-manylinux2014_x86_64.whl", hash = "sha256:d77f13a0f411abb3abd6d49f8b41c1373f72b86b1973236023dc37d563c2d0db"}, + {file = "pysdl2_dll-2.28.4-py2.py3-none-manylinux_2_28_aarch64.whl", hash = "sha256:bba4abda0962025bff2ab0f17ff93f70f09fe706468460a4709533f5550c9bd5"}, + {file = "pysdl2_dll-2.28.4-py2.py3-none-manylinux_2_28_x86_64.whl", hash = "sha256:142133f79236b809850e035c9a7fc77cd1098bdeb5f4edbba818a24f2aa6cf55"}, + {file = "pysdl2_dll-2.28.4-py2.py3-none-win32.whl", hash = "sha256:e417decf74d63cc3f5092385bdfb75cc7815d34b838992f09aff21c40ad27237"}, + {file = "pysdl2_dll-2.28.4-py2.py3-none-win_amd64.whl", hash = "sha256:667628a119e00f45aed279e480516ccc484c2f9a5d03c901dd1996c3af4c5840"}, +] + [[package]] name = "pywin32" version = "306" @@ -1402,6 +1541,18 @@ files = [ {file = "pywin32-306-cp39-cp39-win_amd64.whl", hash = "sha256:39b61c15272833b5c329a2989999dcae836b1eed650252ab1b7bfbe1d59f30f4"}, ] +[[package]] +name = "pywin32-ctypes" +version = "0.2.2" +description = "A (partial) reimplementation of pywin32 using ctypes/cffi" +category = "main" +optional = false +python-versions = ">=3.6" +files = [ + {file = "pywin32-ctypes-0.2.2.tar.gz", hash = "sha256:3426e063bdd5fd4df74a14fa3cf80a0b42845a87e1d1e81f6549f9daec593a60"}, + {file = "pywin32_ctypes-0.2.2-py3-none-any.whl", hash = "sha256:bf490a1a709baf35d688fe0ecf980ed4de11d2b3e37b51e5442587a75d9957e7"}, +] + [[package]] name = "pyyaml" version = "6.0.1" @@ -1873,6 +2024,29 @@ secure = ["certifi", "cryptography (>=1.9)", "idna (>=2.0.0)", "pyopenssl (>=17. 
socks = ["pysocks (>=1.5.6,!=1.5.7,<2.0)"] zstd = ["zstandard (>=0.18.0)"] +[[package]] +name = "validators" +version = "0.22.0" +description = "Python Data Validation for Humans™" +category = "main" +optional = false +python-versions = ">=3.8" +files = [ + {file = "validators-0.22.0-py3-none-any.whl", hash = "sha256:61cf7d4a62bbae559f2e54aed3b000cea9ff3e2fdbe463f51179b92c58c9585a"}, + {file = "validators-0.22.0.tar.gz", hash = "sha256:77b2689b172eeeb600d9605ab86194641670cdb73b60afd577142a9397873370"}, +] + +[package.extras] +docs-offline = ["myst-parser (>=2.0.0)", "pypandoc-binary (>=1.11)", "sphinx (>=7.1.1)"] +docs-online = ["mkdocs (>=1.5.2)", "mkdocs-git-revision-date-localized-plugin (>=1.2.0)", "mkdocs-material (>=9.2.6)", "mkdocstrings[python] (>=0.22.0)", "pyaml (>=23.7.0)"] +hooks = ["pre-commit (>=3.3.3)"] +package = ["build (>=1.0.0)", "twine (>=4.0.2)"] +runner = ["tox (>=4.11.1)"] +sast = ["bandit[toml] (>=1.7.5)"] +testing = ["pytest (>=7.4.0)"] +tooling = ["black (>=23.7.0)", "pyright (>=1.1.325)", "ruff (>=0.0.287)"] +tooling-extras = ["pyaml (>=23.7.0)", "pypandoc-binary (>=1.11)", "pytest (>=7.4.0)"] + [[package]] name = "watchdog" version = "3.0.0" @@ -2039,5 +2213,5 @@ test = ["pytest (>=6.0.0)", "setuptools (>=65)"] [metadata] lock-version = "2.0" -python-versions = "^3.11" -content-hash = "5adbe2b271f9a98bc456e1995fd743db377752eede94e42bcc5dced023d42757" +python-versions = ">=3.11, <3.13" +content-hash = "5757172c816b0e5b7863ffee379028af4ae9d77e6aaa9e3076830030ccdcc539" diff --git a/pyproject.toml b/pyproject.toml index a55208e..5d2ef03 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -7,11 +7,17 @@ license = "MIT" readme = "README.md" [tool.poetry.dependencies] -python = "^3.11" +python = ">=3.11, <3.13" flet = "^0.10.3" faster-whisper = "^0.9.0" pygame = "^2.5.2" torch = "2.0.0" +requests = "^2.31.0" +validators = "^0.22.0" +pyinstaller = "^6.1.0" +pysdl2 = "^0.9.16" +pysdl2-dll = "^2.28.4" +pyaudio = "^0.2.13" [build-system] diff --git 
# whisper_webservice_interface.py -- new file introduced by this patch
# (mode 100644, previously /dev/null).
"""Small HTTP client for a Whisper ASR webservice.

Exposes one function per endpoint used by this module: ``/asr``
(:func:`send_asr_request`) and ``/detect-language`` (:func:`detect_language`).
Both upload the audio as a multipart form field named ``audio_file``.
"""
import os
from typing import Optional, Union, Dict, Any

import requests


def _drop_none(params: Dict[str, Any]) -> Dict[str, Any]:
    """Return a copy of *params* without the entries whose value is ``None``."""
    return {key: value for key, value in params.items() if value is not None}


def _load_audio(audio: Any) -> Any:
    """Return the upload payload for *audio*.

    A ``str`` or ``os.PathLike`` is treated as a file path and read in binary
    mode. Anything else (typically ``bytes``) is returned unchanged, so values
    that were previously passed straight through to ``requests`` still are.
    """
    if isinstance(audio, (str, os.PathLike)):
        with open(audio, 'rb') as f:
            return f.read()
    return audio


def _endpoint(url: str, path: str) -> str:
    """Join the base *url* and *path* without producing a double slash."""
    return f"{url.rstrip('/')}/{path}"


def send_asr_request(url: str, audio_file_path_or_bytes: str | bytes,
                     task: Optional[str] = None, language: Optional[str] = None,
                     initial_prompt: Optional[str] = None, encode: Optional[bool] = None,
                     output: Optional[str] = None, word_timestamps: Optional[bool] = None,
                     timeout: Optional[Union[float, tuple[float, float]]] = None
                     ) -> tuple[int, str]:
    """Send a request to the ASR endpoint (``<url>/asr``).

    Args:
        url: Base URL of the webservice; a trailing slash is tolerated.
        audio_file_path_or_bytes: Path of the audio file to upload, or the
            audio content itself as bytes.
        task, language, initial_prompt, encode, output, word_timestamps:
            Forwarded as query parameters of the same name. Parameters left
            as ``None`` are omitted from the request.
        timeout: Passed to ``requests.post``. ``None`` (the default) waits
            indefinitely, which is the previous behaviour.

    Returns:
        Always a tuple ``(status_code, response_text)``, for successful and
        failed requests alike. The caller must check the status code.

    Raises:
        OSError: If a path was given and the file cannot be read.
        requests.RequestException: If the HTTP request itself fails.
    """
    params = _drop_none({
        "task": task,
        "language": language,
        "initial_prompt": initial_prompt,
        "encode": encode,
        "output": output,
        "word_timestamps": word_timestamps,
    })

    files = {
        'audio_file': _load_audio(audio_file_path_or_bytes)
    }

    response = requests.post(_endpoint(url, "asr"), params=params, files=files,
                             timeout=timeout)

    return response.status_code, response.text


def detect_language(url: str, audio_file_path: str | bytes, encode: Optional[bool] = None,
                    timeout: Optional[Union[float, tuple[float, float]]] = None
                    ) -> Dict[str, Any] | tuple[int, str]:
    """Send a request to the Detect Language endpoint (``<url>/detect-language``).

    Args:
        url: Base URL of the webservice; a trailing slash is tolerated.
        audio_file_path: Path of the audio file to upload. Raw bytes are
            accepted as well, mirroring :func:`send_asr_request`.
        encode: Forwarded as the ``encode`` query parameter; omitted if ``None``.
        timeout: Passed to ``requests.post``. ``None`` (the default) waits
            indefinitely, which is the previous behaviour.

    Returns:
        The decoded JSON body if the status code is 200. It is expected to
        look like ``{'detected_language': '', 'language_code': ''}``.
        Otherwise a tuple ``(status_code, response_text)``.

    Raises:
        OSError: If a path was given and the file cannot be read.
        requests.RequestException: If the HTTP request itself fails.
    """
    params = _drop_none({
        "encode": encode
    })

    files = {
        'audio_file': _load_audio(audio_file_path)
    }

    response = requests.post(_endpoint(url, "detect-language"), params=params, files=files,
                             timeout=timeout)

    if response.status_code == 200:
        return response.json()
    return response.status_code, response.text


# Example usage
def main(
    url: str = "http://127.0.0.1:9000",  # Replace with the actual URL of the webservice
    audio_file_path: str = "/run/media/yannik/IC RECORDER/REC_FILE/Interview01/231021_1541.mp3",
) -> None:
    """Run both endpoints once against *audio_file_path* and print the results.

    The defaults reproduce the original hard-coded example values, so calling
    ``main()`` without arguments behaves as before.
    """
    response = send_asr_request(url, audio_file_path, task="transcribe", language="en")
    print(response)

    response = detect_language(url, audio_file_path)
    print(response)


if __name__ == "__main__":
    main()