# fluesterpost/main.py
#
# 619 lines
# 23 KiB
# Python

import os
import pprint
import traceback
import typing
import requests.exceptions
import validators
import utils
import flet as ft
from typing import DefaultDict
import pygame
import nn_model_manager as mm
import whisper_webservice_interface
import wave
import sys
import pyaudio
# === TEMP ===
import logging
logging.basicConfig()
logging.getLogger("faster_whisper").setLevel(logging.DEBUG)
# === END ===
# globals -- mutable state shared between the UI callbacks defined inside main()
transcribe_ready: bool = False  # set via set_transcribe_ready(); gates starting a recording
recording: bool = False  # True while a microphone recording is in progress
rec_stream: pyaudio.Stream | None = None  # the open input stream while recording
sound_chunks: list[bytes] = []  # raw buffers appended by the PyAudio stream callback
recorded_audio: bytes = b""  # the joined chunks of the most recent recording
# AUDIO stuff -- parameters handed to PyAudio when the input stream is opened
REC_CHUNK = 1024  # frames per buffer
REC_FORMAT = pyaudio.paInt16
REC_CHANNELS = 1
REC_RATE = 16000  # sample rate in Hz
REC_RECORD_SECONDS = 5  # NOTE(review): not referenced anywhere in this file
def main(page):
    """Build the whole UI on the given Flet page (passed as ``target`` to ``ft.app``).

    This first part initialises the pygame mixer, enumerates the audio input
    devices, and creates the ``ft.Ref`` handles through which the nested
    callbacks reach their controls.
    """
    pygame.mixer.init()

    # Collect (device index, device name) for every device that can capture audio.
    p = pyaudio.PyAudio()
    capture_devices = []
    for device_index in range(p.get_device_count()):
        device_info = p.get_device_info_by_index(device_index)
        if device_info['maxInputChannels'] > 0:
            capture_devices.append((device_index, device_info['name']))

    # recording / file browser controls
    record_button = ft.Ref[ft.IconButton]()
    mic_select = ft.Ref[ft.Dropdown]()
    file_tree = ft.Ref[ft.Column]()
    file_tree_empty_text = ft.Ref[ft.Text]()

    # mode select
    current_mode_select = ft.Ref[ft.Dropdown]()
    current_mode_info_text = ft.Ref[ft.Text]()
    processing_spinner = ft.Ref[ft.ProgressRing]()

    # local model mode
    model_size_select = ft.Ref[ft.Dropdown]()
    model_device_select = ft.Ref[ft.Dropdown]()
    # model_bits_select = ft.Ref[ft.Dropdown]()
    model_load_unload_button = ft.Ref[ft.IconButton]()

    # docker whisper webservice mode
    whisper_webservice_url_input = ft.Ref[ft.TextField]()

    # one ref per "transcribe this file" button, so they can be toggled together
    transcribe_buttons: list[ft.Ref[ft.IconButton]] = []

    # transcription output area
    output_text_container = ft.Ref[ft.Container]()
    output_text_col = ft.Ref[ft.Column]()
def transcribe(fileOrBytes: str | bytes):
    """Transcribe an audio file or a raw recording and show the text in the output panel.

    :param fileOrBytes: path of an audio file (``str``) or recorded audio (``bytes``).

    In ``local`` mode the loaded model from ``nn_model_manager`` is used; file
    paths must end in ``.mp3`` and a model must be loaded, otherwise the call
    is a no-op. In ``webservice`` mode the audio is sent to the URL from the
    URL text field; an invalid URL makes the call a no-op.

    While a transcription runs, the output panel shows a progress ring and all
    per-file transcribe buttons are disabled. They are re-enabled in a
    ``finally`` clause, so no failure can leave them stuck.
    """

    def set_buttons_disabled(disabled: bool):
        # Toggle every per-file transcribe button at once.
        for btn in transcribe_buttons:
            btn.current.disabled = disabled

    def show(controls: list, alignment):
        # Replace the content of the output panel.
        output_text_container.current.alignment = alignment
        output_text_col.current.controls = controls

    description = fileOrBytes if isinstance(fileOrBytes, str) else f'with len {len(fileOrBytes)}'
    print(f"DEBUG: trying to transcribe audio {description}")

    mode = current_mode_select.current.value

    # === LOCAL MODEL CODE ===
    if mode == 'local':
        if not mm.is_model_loaded() or (isinstance(fileOrBytes, str) and not fileOrBytes.endswith('.mp3')):
            print("DEBUG: can't transcribe a non-MP3 file or while no model is loaded")
            return

        print("DEBUG: starting transcription")
        show([ft.ProgressRing()], ft.alignment.center)
        set_buttons_disabled(True)
        page.update()

        try:
            if isinstance(fileOrBytes, str):
                segments, _info = mm.transcribe_from_file(fileOrBytes)
            else:
                segments, _info = mm.transcribe_from_i16_audio(fileOrBytes)

            # Consume the segments inside the try block: presumably they are
            # produced lazily, so errors may only surface while iterating.
            txt = ''.join(seg.text + '\n' for seg in segments)
            show([ft.Text(txt, selectable=True)], ft.alignment.top_left)  # TODO
        except Exception as e:
            show([ft.Text(f"Transcribing failed: {str(e)}")], ft.alignment.center)  # TODO
        finally:
            set_buttons_disabled(False)
            page.update()

    # === WEBSERVICE MODE CODE ===
    elif mode == 'webservice':
        url = whisper_webservice_url_input.current.value
        print("DEBUG: starting web transcription")

        if not validators.url(url, simple_host=True):
            print("DEBUG: can't transcribe, the webservice URL is not valid")
            return

        show([ft.ProgressRing()], ft.alignment.center)
        set_buttons_disabled(True)
        page.update()

        try:
            print('DEBUG: sending web request...')
            code, text = whisper_webservice_interface.send_asr_request(url, fileOrBytes, task="transcribe")
        except requests.exceptions.RequestException as e:
            print(f'web transcription failed: {str(e)}')
            show([ft.Text(f"HTTP Request to {url}/asr failed. Reason:\n{str(e)}")], ft.alignment.center)
        except Exception as e:
            # Same treatment as the local branch: report the error instead of
            # leaving the spinner on screen forever.
            print(traceback.format_exc())
            show([ft.Text(f"Transcribing failed: {str(e)}")], ft.alignment.center)
        else:
            if code == 200:
                show([ft.Text(text, selectable=True)], ft.alignment.top_left)
            else:
                show([ft.Text(f"HTTP Request to {url}/asr failed ({code}):\n{text}")], ft.alignment.center)
        finally:
            set_buttons_disabled(False)
            page.update()
def generate_file_tree(path: str, tree_dict: dict | DefaultDict):
    """Recursively build the Flet controls that display one folder of the file tree.

    :param path: file-system path of the folder; one trailing separator is dropped.
    :param tree_dict: mapping for this folder. Keys other than ``utils.FILES_KEY``
        and ``'.'`` are treated as sub-folders whose values are nested mappings;
        ``utils.FILES_KEY`` holds the file names. For the root, the files are
        looked up under the ``'.'`` entry.
    :return: an ``ft.Row`` with the folder header, the sub-folders and one row
        per ``.mp3`` file (a play/stop button and a transcribe button each).

    Every transcribe button created here is registered in ``transcribe_buttons``.
    """

    def start_playing(filepath: str, button_ref: ft.Ref[ft.IconButton]):
        # Start playback and turn the button into a stop button. Ignored while
        # something is already playing or when the file no longer exists.
        print(f"trying to play {filepath}...")
        if pygame.mixer.music.get_busy() or not os.path.isfile(filepath):
            return
        print("starting playback")
        pygame.mixer.music.load(filepath)
        pygame.mixer.music.play()
        button_ref.current.icon = ft.icons.PAUSE_CIRCLE_FILLED_OUTLINED
        button_ref.current.on_click = lambda _, f=filepath, r=button_ref: stop_playing(f, r)
        page.update()

    def stop_playing(filepath: str, button_ref: ft.Ref[ft.IconButton]):
        # Stop playback and turn the button back into a play button.
        print("stopping playback")
        pygame.mixer.music.stop()
        button_ref.current.icon = ft.icons.PLAY_CIRCLE_OUTLINED
        button_ref.current.on_click = lambda _, f=filepath, r=button_ref: start_playing(f, r)
        page.update()

    # endswith() instead of path[-1] so that an empty path does not raise.
    if path.endswith(os.sep):
        path = path[:-1]

    folder_name = utils.get_last_segment(path)
    print(f"DEBUG: generating tree for folder {folder_name}")

    # folder header first
    print(f"adding name {folder_name} to ui")
    controls = [
        ft.Row(
            [
                ft.Icon(ft.icons.FOLDER, color=ft.colors.BLUE),
                ft.Text(folder_name, size=14, weight=ft.FontWeight.BOLD),
            ]
        )
    ]

    # then one nested tree per sub-folder
    for subfolder_name, subtree in tree_dict.items():
        if subfolder_name == utils.FILES_KEY or subfolder_name == '.':
            continue  # not a sub-folder
        controls.append(generate_file_tree(path + os.sep + subfolder_name, subtree))

    # now folders are there, let's do files
    if utils.FILES_KEY not in tree_dict and '.' in tree_dict:
        tree_dict = tree_dict['.']  # if root dir, enter root dir (.) directory

    # Whether transcription is currently possible; identical for every file,
    # so it is evaluated once instead of once per file.
    mode_value = current_mode_select.current.value
    can_transcribe = bool(
        (mode_value == 'local' and mm.is_model_loaded())
        or (mode_value == 'webservice'
            and validators.url(whisper_webservice_url_input.current.value, simple_host=True))
    )

    files_controls = []
    # .get(): a folder mapping without a files entry is treated as "no files".
    for file in tree_dict.get(utils.FILES_KEY, []):
        if not file.endswith('.mp3'):
            continue

        full_file_path = path + os.sep + file

        play_button_ref = ft.Ref[ft.IconButton]()
        transcribe_button_ref = ft.Ref[ft.IconButton]()
        transcribe_buttons.append(transcribe_button_ref)

        files_controls.append(ft.Row([
            ft.Text(file),
            ft.IconButton(icon=ft.icons.PLAY_CIRCLE_OUTLINED, ref=play_button_ref,
                          on_click=lambda _, f=full_file_path, r=play_button_ref: start_playing(f, r)),
            ft.IconButton(icon=ft.icons.FORMAT_ALIGN_LEFT, disabled=not can_transcribe,
                          ref=transcribe_button_ref,
                          on_click=lambda _, f=full_file_path: transcribe(f)),
        ]))

    if not files_controls:
        files_controls.append(ft.Text('No mp3 Files found', color='grey'))

    return ft.Row([
        ft.VerticalDivider(),
        ft.Column(controls + [ft.Row([ft.VerticalDivider(), ft.Column(files_controls)])])
    ])
def on_dialog_result(e: ft.FilePickerResultEvent):
    """Handle the result of the "Add Folder" directory picker.

    If a directory was chosen and ``utils.build_file_tree`` returns a tree with
    a root (``'.'``) entry, the rendered tree is appended to the file tree
    column and the "no folder open" placeholder is hidden.

    :param e: file picker result event; only ``e.path`` is read.
    """
    path = e.path
    if not path:
        return  # dialog was cancelled

    print(f"path is {path}")
    try:
        if os.path.isdir(path):
            tree = utils.build_file_tree(path)

            if '.' in tree:  # if there is actually a proper file tree
                # add to view
                file_tree.current.controls.append(
                    generate_file_tree(path, utils.defaultdict_to_dict(tree))
                )
                file_tree_empty_text.current.visible = False
                page.update()
    except Exception as exc:
        # The original ``except e:`` named the event object instead of an
        # exception class, which itself raises TypeError once anything fails.
        # This is a UI callback boundary, so log the failure and keep running.
        print(f"adding folder {path} failed: {exc}")
        print(traceback.format_exc())
def mode_select():
    """Apply the mode chosen in the mode dropdown ('local' or 'webservice').

    Shows the controls belonging to the selected mode, hides those of the
    other one, updates the hint text, and stores the selection in client
    storage under ``'selected_mode'``. Does nothing (apart from a log line)
    while a local model is loaded.

    :raises ValueError: if the dropdown holds a value other than the two modes.
    """
    if mm.is_model_loaded():
        print("BUG: cannot change mode while model is loaded!")
        return

    next_mode = current_mode_select.current.value
    if next_mode == 'local':
        # show the model selects & load button
        model_size_select.current.visible = True
        model_device_select.current.visible = True
        model_load_unload_button.current.visible = True
        model_size_select.current.disabled = False
        model_device_select.current.disabled = False

        whisper_webservice_url_input.current.visible = False

        # Restore the local-mode hint; otherwise the webservice hint would
        # remain visible after switching back.
        current_mode_info_text.current.value = 'Select parameters, and then load transcription model.'

        # No model is loaded at this point, so nothing can be transcribed yet.
        # This also disables all transcribe buttons and refreshes the page.
        set_transcribe_ready(False)
    elif next_mode == 'webservice':
        # hide and disable the model selects & load button
        model_size_select.current.visible = False
        model_device_select.current.visible = False
        model_load_unload_button.current.visible = False
        model_size_select.current.disabled = True
        model_device_select.current.disabled = True
        model_load_unload_button.current.disabled = True
        current_mode_info_text.current.value = 'Input the URL of the onerahmet/openai-whisper-asr-webservice docker container'

        whisper_webservice_url_input.current.visible = True
        whisper_webservice_url_input.current.disabled = False

        # Readiness depends on whether the URL currently entered is valid.
        on_url_input(None)
    else:
        raise ValueError(f'BUG: Impossible mode {next_mode} received!')

    page.update()
    page.client_storage.set('selected_mode', next_mode)
def load_model():
    """Load the local model with the size and device chosen in the dropdowns.

    The UI is locked while loading. On success the chosen size and device are
    written to client storage and the UI is marked ready; on failure the
    reason is shown in the info text and the UI is marked not ready.
    """
    info_text = current_mode_info_text.current
    info_text.value = 'Loading... This may take a while.'
    page.update()

    paralyze_ui()

    try:
        mm.set_model(
            size=model_size_select.current.value or 'base',
            device=model_device_select.current.value or 'auto',
            # compute_type=model_bits_select.current.value or '16bit',
        )
    except Exception as err:
        print(f"loading model failed. Exception: {str(err)}")
        print(traceback.format_exc())
        info_text.value = f'Loading failed. Reason:\n{str(err)}'
        set_transcribe_ready(False)
        # raise err

    processing_spinner.current.visible = False

    if not mm.is_model_loaded():
        set_transcribe_ready(False)
        return

    info_text.value = 'Loaded.'

    # loading worked, so remember the selection for the next start
    page.client_storage.set('model_size', model_size_select.current.value)
    page.client_storage.set('device_select', model_device_select.current.value)

    # enables all transcribe buttons
    set_transcribe_ready(True)
def unload_model():
    """Unload the local model (if one is loaded) and return the UI to the not-ready state.

    The UI is locked first. Afterwards it is always unlocked again according
    to whether a model is still loaded, even if unloading raised, so a failure
    cannot leave every control disabled.
    """
    # disables all transcribe buttons and the selects while unloading
    paralyze_ui()

    try:
        if mm.is_model_loaded():
            mm.unload_model()
    finally:
        still_loaded = mm.is_model_loaded()
        if not still_loaded:
            # Replace the stale "Loaded." status with the initial local-mode hint.
            current_mode_info_text.current.value = 'Select parameters, and then load transcription model.'
        set_transcribe_ready(still_loaded)
def paralyze_ui(spinner: bool = True, disable_recording_button: bool = True):
    """Lock the UI while a long-running action (loading, unloading, recording) is active.

    Disables the model selects, the mode select, the load/unload button
    (switching its icon to CLOSE) and every transcribe button, then refreshes
    the page.

    :param spinner: whether the processing spinner is shown.
    :param disable_recording_button: whether the record button is disabled too;
        pass ``False`` to keep it clickable.
    """
    model_size_select.current.disabled = True
    model_device_select.current.disabled = True
    # model_bits_select.current.disabled = True
    current_mode_select.current.disabled = True

    # The original set this button's disabled flag to True, False and True
    # again; only the final state (CLOSE icon, disabled) is kept.
    load_button = model_load_unload_button.current
    load_button.icon = ft.icons.CLOSE
    load_button.disabled = True

    processing_spinner.current.visible = spinner
    record_button.current.disabled = disable_recording_button

    for btn in transcribe_buttons:
        btn.current.disabled = True

    page.update()
def set_transcribe_ready(rdy: bool):
    """Record whether transcription is possible and bring the controls in line with it.

    Ready: transcribe buttons and record button enabled, model selects
    disabled, load/unload button acts as "unload" (CLOSE icon), and the mode
    select is disabled if a local model is loaded.
    Not ready: transcribe buttons and record button disabled, model selects
    and mode select enabled, load/unload button acts as "load" (START icon).
    In both cases the spinner is hidden, the load/unload button is enabled
    and the page is refreshed.

    :param rdy: the new readiness state, also stored in the global ``transcribe_ready``.
    """
    global transcribe_ready
    transcribe_ready = rdy

    is_ready = bool(transcribe_ready)

    for btn in transcribe_buttons:
        btn.current.disabled = not is_ready

    # the model parameters can only be changed while nothing is ready/loaded
    model_size_select.current.disabled = is_ready
    model_device_select.current.disabled = is_ready
    # model_bits_select.current.disabled = is_ready

    load_button = model_load_unload_button.current
    if is_ready:
        load_button.on_click = lambda _: unload_model()
        load_button.icon = ft.icons.CLOSE
    else:
        load_button.icon = ft.icons.START
        load_button.on_click = lambda _: load_model()
    load_button.disabled = False

    processing_spinner.current.visible = False
    record_button.current.disabled = not is_ready

    if not is_ready:
        current_mode_select.current.disabled = False
    elif mm.is_model_loaded():
        current_mode_select.current.disabled = True

    page.update()
def on_url_input(e):
    """React to a change of the webservice URL field.

    A valid URL is stored in client storage under ``'webservice_url'`` and
    marks the UI ready for transcription; an invalid one marks it not ready.

    :param e: the change event; unused (``mode_select`` passes ``None``).
    """
    candidate_url = whisper_webservice_url_input.current.value

    url_is_valid = bool(validators.url(candidate_url, simple_host=True))
    if url_is_valid:
        page.client_storage.set('webservice_url', candidate_url)

    # enables or disables all transcribe buttons accordingly
    set_transcribe_ready(url_is_valid)

    page.update()
# Debug output of the saved microphone selection. Nothing is stored on the
# first run, so guard against None instead of crashing in tuple(None).
_saved_mic_debug = (page.client_storage.get('selected_mic')
                    if page.client_storage.contains_key('selected_mic') else None)
print(tuple(_saved_mic_debug) if _saved_mic_debug else None)
def toggle_recording():
    """Start a microphone recording, or stop the running one and transcribe it.

    Starting: only allowed while ``transcribe_ready`` is set. Opens a
    callback-driven PyAudio input stream on the microphone selected in the
    dropdown (the default input device if the name cannot be resolved), tints
    the record button and locks the rest of the UI.

    Stopping: stops and closes the stream, joins the recorded chunks into
    ``recorded_audio``, transcribes them, and restores the UI. The recording
    state is reset even if stopping or transcribing fails.
    """
    global recording
    global rec_stream
    global sound_chunks
    global recorded_audio

    if recording:
        print("Stopping recording...")
        try:
            rec_stream.stop_stream()
            rec_stream.close()  # release the audio device; the stream is not reused
            rec_stream = None

            recorded_audio = b"".join(sound_chunks)
            set_transcribe_ready(False)
            transcribe(recorded_audio)
        finally:
            # Reset unconditionally, otherwise a failure above would leave the
            # app permanently in the "recording" state.
            recording = False
            # sound = pygame.mixer.Sound(buffer=recorded_audio) # doesn't work because sampling rate is wrong
            record_button.current.bgcolor = "0x000000FF"
            set_transcribe_ready(True)
        print("done")
        # sound.play()
    else:
        if not transcribe_ready:
            print("Can't record, not ready")
            return

        print("Starting Recording...")
        sound_chunks = []

        def cb(in_data, _frame_count, _time_info, _status):
            # Invoked by PyAudio for every captured buffer; keep it minimal
            # (no printing) so the audio callback is not slowed down.
            sound_chunks.append(in_data)
            return in_data, pyaudio.paContinue

        # Honour the microphone chosen in the dropdown; None selects the default device.
        selected_name = mic_select.current.value
        device_index = next((idx for idx, name in capture_devices if name == selected_name), None)

        try:
            rec_stream = p.open(
                format=REC_FORMAT,
                channels=REC_CHANNELS,
                rate=REC_RATE,
                input=True,
                input_device_index=device_index,
                frames_per_buffer=REC_CHUNK,
                stream_callback=cb
            )
            rec_stream.start_stream()
        except (OSError, ValueError) as exc:
            # Device missing/busy or parameters unsupported: stay in the idle state.
            print(f"Can't record, opening the input stream failed: {exc}")
            rec_stream = None
            return

        # Only flag the recording as running once the stream really is open.
        recording = True
        record_button.current.bgcolor = "0xFFFF4444"
        paralyze_ui(spinner=False, disable_recording_button=False)
def find_recordingdevice_tuple_by_name(search_name: str) -> typing.Tuple[int, str] | None:
return next(((device_id, name) for device_id, name in capture_devices if name == search_name))
# set up file picker
file_picker = ft.FilePicker(on_result=on_dialog_result)
page.overlay.append(file_picker)

page.add(
    ft.Text("Flüsterpost", style=ft.TextThemeStyle.TITLE_LARGE),
    ft.Divider()
)


def stored_value(key: str, default):
    """Return the client-storage value saved under *key*, or *default* if there is none."""
    return page.client_storage.get(key) if page.client_storage.contains_key(key) else default


def on_mic_change(_):
    """Persist the newly selected microphone as its (device_id, name) entry."""
    selected_name = mic_select.current.value
    if selected_name:
        page.client_storage.set('selected_mic', find_recordingdevice_tuple_by_name(selected_name))


mode = stored_value('selected_mode', 'local')

# Pre-select the saved microphone if it still exists, otherwise the first
# capture device. With no capture device at all nothing is selected, where
# indexing capture_devices[0] would have raised IndexError.
saved_mic = stored_value('selected_mic', None)
if saved_mic and tuple(saved_mic) in capture_devices:
    initial_mic = saved_mic[1]
elif capture_devices:
    initial_mic = capture_devices[0][1]
else:
    initial_mic = None

page.add(
    ft.ResponsiveRow([
        # === LEFT: recording controls and file tree ===
        ft.Container(
            ft.Column([
                ft.Row([
                    ft.ElevatedButton("Add Folder", on_click=lambda _: file_picker.get_directory_path()),
                    ft.Container(expand=True),
                    ft.IconButton(ft.icons.RECORD_VOICE_OVER, ref=record_button,
                                  on_click=lambda _: toggle_recording()),
                ]),
                ft.Dropdown(
                    ref=mic_select,
                    options=[ft.dropdown.Option(x[1]) for x in capture_devices],
                    value=initial_mic,
                    height=36,
                    content_padding=2,
                    on_change=on_mic_change,
                ),
                ft.Column(ref=file_tree, scroll=ft.ScrollMode.ALWAYS, expand=True),
                # ft.ListView(ref=file_tree),
                # TextThemeStyle (as used for the title above) is the enum that
                # ``style`` expects; ft.TextTheme.body_small is not a style value.
                ft.Text("No Folder Open Yet", style=ft.TextThemeStyle.BODY_SMALL, color="grey",
                        ref=file_tree_empty_text),
            ], expand=True), expand=True, col=4),
        # === RIGHT: mode / model controls and transcription output ===
        ft.Container(expand=True, content=ft.Column(expand=True, controls=[
            ft.Column([
                ft.Text(
                    'Select parameters, and then load transcription model.'
                    if mode == 'local'
                    else 'Input the URL of the onerahmet/openai-whisper-asr-webservice docker container',
                    ref=current_mode_info_text),
                ft.Row([
                    ft.Dropdown(
                        ref=current_mode_select,
                        width=160,
                        hint_text='mode',
                        value=mode,
                        on_change=lambda _: mode_select(),
                        options=[
                            ft.dropdown.Option('local'),
                            ft.dropdown.Option('webservice'),
                        ],
                    ),
                    # === LOCAL MODE ===
                    ft.Dropdown(
                        ref=model_size_select,
                        width=100,
                        hint_text='model size',
                        value=stored_value('model_size', 'base'),
                        options=[ft.dropdown.Option(x) for x in mm.ModelSize.__args__],
                        # __args__ is not perfect here. But works.
                        visible=mode == 'local',
                    ),
                    ft.Dropdown(
                        ref=model_device_select,
                        width=100,
                        hint_text='device',
                        value=stored_value('device_select', 'auto'),
                        options=[ft.dropdown.Option(x) for x in mm.Device.__args__],
                        visible=mode == 'local',
                        # __args__ is not perfect here. But works.
                    ),
                    # ft.Dropdown(
                    #    ref=model_bits_select,
                    #    width=100,
                    #    hint_text='bits',
                    #    value='16bit',
                    #    options=[ft.dropdown.Option(x) for x in mm.ComputeType.__args__]  # __args__ is not perfect here. But works.
                    # ),
                    ft.IconButton(
                        icon=ft.icons.START,
                        ref=model_load_unload_button,
                        on_click=lambda _: load_model(),
                        visible=mode == 'local',
                    ),
                    # === WEBSERVICE MODE ===
                    ft.TextField(
                        ref=whisper_webservice_url_input,
                        visible=mode == 'webservice',
                        on_change=on_url_input,
                        hint_text='e.g. http://localhost:9000',
                        value=stored_value('webservice_url', ''),
                    ),
                    # TODO: question mark hint button about what the web service is
                    # === GENERAL ===
                    ft.ProgressRing(ref=processing_spinner, visible=False)
                ])
            ]),
            ft.Container(expand=True, padding=12, border=ft.border.all(2, 'grey'),
                         alignment=ft.alignment.center,
                         ref=output_text_container,
                         content=ft.Column(
                             [ft.Text('Nothing to see here!', text_align=ft.TextAlign.CENTER)],
                             ref=output_text_col,
                             expand=True,
                             scroll=ft.ScrollMode.ADAPTIVE)),
        ]), col=8)
    ], expand=True),
)

# refresh all values, and make sure the right stuff is shown
mode_select()
# Start the Flet app only when this file is executed as a script, so that
# importing the module (e.g. from tooling or tests) does not open the UI.
if __name__ == "__main__":
    ft.app(target=main)