323 lines
		
	
	
		
			13 KiB
		
	
	
	
		
			Python
		
	
	
	
	
	
			
		
		
	
	
			323 lines
		
	
	
		
			13 KiB
		
	
	
	
		
			Python
		
	
	
	
	
	
| import os
 | |
| import pprint
 | |
| import traceback
 | |
| 
 | |
| import utils
 | |
| import flet as ft
 | |
| 
 | |
| from typing import DefaultDict
 | |
| 
 | |
| import pygame
 | |
| 
 | |
| import nn_model_manager as mm
 | |
| 
 | |
| 
 | |
def main(page):
    """Build the Flüsterpost UI on *page* and wire up all event handlers.

    Used as the ``ft.app`` target. The page contains:

    * a folder browser (left) that lists the ``.mp3`` files of picked folders,
      each with a play/stop button (pygame) and a transcribe button;
    * a model panel (right) for loading/unloading a transcription model via
      ``nn_model_manager`` and showing the transcription output.

    Args:
        page: The flet page to populate.
    """
    pygame.mixer.init()

    # Left-hand folder browser.
    file_tree = ft.Ref[ft.Column]()
    file_tree_empty_text = ft.Ref[ft.Text]()

    # Model selection / loading controls.
    load_model_text = ft.Ref[ft.Text]()
    model_size_select = ft.Ref[ft.Dropdown]()
    model_device_select = ft.Ref[ft.Dropdown]()
    # model_bits_select = ft.Ref[ft.Dropdown]()
    model_load_unload_button = ft.Ref[ft.IconButton]()
    model_loading_spinner = ft.Ref[ft.ProgressRing]()

    # Every per-file transcribe button created so far, so that model
    # load/unload and running transcriptions can toggle them together.
    transcribe_buttons: list[ft.Ref[ft.IconButton]] = []

    # Transcription output area.
    output_text_container = ft.Ref[ft.Container]()
    output_text_col = ft.Ref[ft.Column]()

    def set_transcribe_buttons_disabled(disabled: bool):
        """Set ``disabled`` on every transcribe button (caller must update the page)."""
        for btn in transcribe_buttons:
            btn.current.disabled = disabled

    def stored_choice(key: str, choices, default: str):
        """Return the client-storage value for *key* if it is one of *choices*.

        Falls back to *default* when the key is missing or holds a value that
        is no longer a valid option (e.g. saved by an older version).
        """
        value = page.client_storage.get(key) if page.client_storage.contains_key(key) else None
        return value if value in choices else default

    def start_playing(filepath: str, button_ref: ft.Ref[ft.IconButton]):
        """Play *filepath* and turn *button_ref* into a stop button.

        Does nothing if music is already playing or the file no longer exists.
        """
        print(f"trying to play {filepath}...")
        if pygame.mixer.music.get_busy() or not os.path.isfile(filepath):
            return

        print("starting playback")

        pygame.mixer.music.load(filepath)
        pygame.mixer.music.play()

        button_ref.current.icon = ft.icons.PAUSE_CIRCLE_FILLED_OUTLINED
        button_ref.current.on_click = lambda _, f=filepath, r=button_ref: stop_playing(f, r)
        page.update()

    def stop_playing(filepath: str, button_ref: ft.Ref[ft.IconButton]):
        """Stop playback and turn *button_ref* back into a play button for *filepath*."""
        print("stopping playback")

        pygame.mixer.music.stop()

        button_ref.current.icon = ft.icons.PLAY_CIRCLE_OUTLINED
        button_ref.current.on_click = lambda _, f=filepath, r=button_ref: start_playing(f, r)
        page.update()

    def transcribe(filepath: str):
        """Transcribe *filepath* with the loaded model and show the result.

        While running, a spinner replaces the output and all transcribe buttons
        are disabled. Failures are shown in the output area instead of raised.
        Returns immediately if no model is loaded or the file is not an mp3.
        """
        print(f"DEBUG: trying to transcribe file {filepath}")
        if not mm.is_model_loaded() or not filepath.endswith('.mp3'):
            return

        print("DEBUG: starting transcription")
        output_text_container.current.alignment = ft.alignment.center
        output_text_col.current.controls = [ft.ProgressRing()]
        set_transcribe_buttons_disabled(True)
        page.update()

        try:
            segments, _info = mm.transcribe_from_file(filepath)
            # `segments` is only iterated once, so this also works if it is lazy.
            txt = ''.join(seg.text + '\n' for seg in segments)

            output_text_container.current.alignment = ft.alignment.top_left
            output_text_col.current.controls = [ft.Text(txt, selectable=True)]  # TODO
        except Exception as e:
            print(traceback.format_exc())
            output_text_container.current.alignment = ft.alignment.center
            output_text_col.current.controls = [ft.Text(f"Transcribing failed: {str(e)}")]  # TODO
        finally:
            # The unload button stays clickable during transcription, so only
            # re-enable the buttons if a model is still loaded.
            set_transcribe_buttons_disabled(not mm.is_model_loaded())
            page.update()

    def build_file_row(full_file_path: str, file_name: str) -> ft.Row:
        """Return the row for one mp3 file: name, play button, transcribe button."""
        play_button_ref = ft.Ref[ft.IconButton]()
        transcribe_button_ref = ft.Ref[ft.IconButton]()

        row = ft.Row([
            ft.Text(file_name),
            ft.IconButton(icon=ft.icons.PLAY_CIRCLE_OUTLINED, ref=play_button_ref,
                          on_click=lambda _, f=full_file_path, r=play_button_ref: start_playing(f, r)),
            ft.IconButton(icon=ft.icons.FORMAT_ALIGN_LEFT, disabled=not mm.is_model_loaded(),
                          ref=transcribe_button_ref,
                          on_click=lambda _, f=full_file_path: transcribe(f)),
        ])

        transcribe_buttons.append(transcribe_button_ref)
        return row

    def generate_file_tree(path: str, tree_dict: dict | DefaultDict):
        """Build the UI subtree for folder *path* described by *tree_dict*.

        Every key other than ``utils.FILES_KEY`` and ``'.'`` is treated as a
        subfolder and rendered recursively. The folder's ``.mp3`` files each
        get a row with play and transcribe buttons. When ``utils.FILES_KEY`` is
        absent but ``'.'`` is present (the picked root folder), the file list is
        read from the ``'.'`` entry.

        Args:
            path: Folder path; a single trailing separator is dropped.
            tree_dict: Folder description, as produced by
                ``utils.build_file_tree`` (assumed layout: subfolder name ->
                subtree, ``utils.FILES_KEY`` -> file names — confirm in utils).

        Returns:
            An ``ft.Row`` with the folder header, its subfolders and its files.
        """
        if path.endswith(os.sep):  # endswith(): no IndexError on ""
            path = path[:-1]

        folder_name = utils.get_last_segment(path)
        print(f"DEBUG: generating tree for folder {folder_name}")

        # find folders, and add dict for each
        print(f"adding name {folder_name} to ui")

        controls = [
            ft.Row([
                ft.Icon(ft.icons.FOLDER, color=ft.colors.BLUE),
                ft.Text(folder_name, size=14, weight=ft.FontWeight.BOLD),
            ])
        ]

        for sub_name, subtree in tree_dict.items():
            if sub_name in (utils.FILES_KEY, '.'):
                continue  # files are handled below
            controls.append(generate_file_tree(path + os.sep + sub_name, subtree))

        # now folders are there, let's do files
        if utils.FILES_KEY not in tree_dict and '.' in tree_dict:
            tree_dict = tree_dict['.']  # if root dir, enter root dir (.) directory

        # .get(): a folder without a files entry just shows the "no mp3" hint.
        files_controls = [
            build_file_row(path + os.sep + file, file)
            for file in tree_dict.get(utils.FILES_KEY, [])
            if file.endswith('.mp3')
        ]

        if not files_controls:
            files_controls.append(ft.Text('No mp3 Files found', color='grey'))

        return ft.Row([
            ft.VerticalDivider(),
            ft.Column(controls + [ft.Row([ft.VerticalDivider(), ft.Column(files_controls)])]),
        ])

    def on_dialog_result(e: ft.FilePickerResultEvent):
        """Add the folder chosen in the directory picker to the file tree.

        Ignored if the dialog was cancelled or the path is not a directory.
        Errors while building the tree are logged rather than raised.
        """
        path = e.path
        if not path:
            return

        print(f"path is {path}")
        try:
            if os.path.isdir(path):
                tree = utils.build_file_tree(path)

                if '.' in tree:  # if there is actually a proper file tree
                    # add to view
                    file_tree.current.controls.append(
                        generate_file_tree(path, utils.defaultdict_to_dict(tree))
                    )
                    file_tree_empty_text.current.visible = False

                page.update()
        except Exception:
            # Must name an exception class: the old `except e:` named the event
            # object and raised TypeError instead of handling the error.
            print("didn't work aaa")  # TODO: fix
            print(traceback.format_exc())

    def load_model():
        """Load the model selected in the dropdowns via ``mm.set_model``.

        Locks the controls and shows a spinner while loading. On success, the
        button becomes an unload button, the selection is saved to client
        storage and the transcribe buttons are enabled. On failure, the error
        is shown and the selection controls are unlocked again.
        """
        load_model_text.current.value = 'Loading... This may take a while.'

        model_size_select.current.disabled = True
        model_device_select.current.disabled = True
        # model_bits_select.current.disabled = True
        model_load_unload_button.current.disabled = True
        model_loading_spinner.current.visible = True
        page.update()

        size = model_size_select.current.value or 'base'
        device = model_device_select.current.value or 'auto'

        try:
            mm.set_model(
                size=size,
                device=device,
                # compute_type=model_bits_select.current.value or '16bit',
            )
        except Exception as e:
            print(f"loading model failed. Exception: {str(e)}")
            print(traceback.format_exc())
            load_model_text.current.value = f'Loading failed. Reason:\n{str(e)}'
            model_size_select.current.disabled = False
            model_device_select.current.disabled = False
            # model_bits_select.current.disabled = False

        model_loading_spinner.current.visible = False
        model_load_unload_button.current.disabled = False

        if mm.is_model_loaded():
            load_model_text.current.value = 'Loaded.'
            model_load_unload_button.current.icon = ft.icons.CLOSE
            model_load_unload_button.current.on_click = lambda _: unload_model()

            # if successful, save the parameters actually used to shared preferences
            page.client_storage.set('model_size', size)
            page.client_storage.set('device_select', device)

            set_transcribe_buttons_disabled(False)

        page.update()

    def unload_model():
        """Unload the current model and reset the model panel for a new load."""
        model_load_unload_button.current.disabled = True
        # Nothing can be transcribed without a model.
        set_transcribe_buttons_disabled(True)
        page.update()

        if mm.is_model_loaded():
            mm.unload_model()

        load_model_text.current.value = 'Select parameters, and then load transcription model.'
        model_size_select.current.disabled = False
        model_device_select.current.disabled = False
        # model_bits_select.current.disabled = False
        model_load_unload_button.current.disabled = False
        model_load_unload_button.current.icon = ft.icons.START
        model_load_unload_button.current.on_click = lambda _: load_model()
        model_loading_spinner.current.visible = False
        page.update()

    # set up file picker
    file_picker = ft.FilePicker(on_result=on_dialog_result)
    page.overlay.append(file_picker)

    page.add(
        ft.Text("Flüsterpost", style=ft.TextThemeStyle.TITLE_LARGE),
        ft.Divider()
    )

    page.add(
        ft.ResponsiveRow([
            ft.Container(
                ft.Column([
                    ft.ElevatedButton("Add Folder", on_click=lambda _: file_picker.get_directory_path()),
                    ft.Column(ref=file_tree, scroll=ft.ScrollMode.ALWAYS, expand=True),
                    # ft.ListView(ref=file_tree),
                    ft.Text("No Folder Open Yet", style=ft.TextThemeStyle.BODY_SMALL, color="grey",
                            ref=file_tree_empty_text),
                ], expand=True), expand=True, col=4),
            ft.Container(expand=True, content=ft.Column(expand=True, controls=[
                ft.Column([
                    ft.Text('Select parameters, and then load transcription model.', ref=load_model_text),
                    ft.Row([
                        ft.Dropdown(
                            ref=model_size_select,
                            width=100,
                            hint_text='model size',
                            value=stored_choice('model_size', mm.ModelSize.__args__, 'base'),
                            options=[ft.dropdown.Option(x) for x in mm.ModelSize.__args__],  # __args__ is not perfect here. But works.
                        ),
                        ft.Dropdown(
                            ref=model_device_select,
                            width=100,
                            hint_text='device',
                            value=stored_choice('device_select', mm.Device.__args__, 'auto'),
                            options=[ft.dropdown.Option(x) for x in mm.Device.__args__]  # __args__ is not perfect here. But works.
                        ),
                        # ft.Dropdown(
                        #    ref=model_bits_select,
                        #    width=100,
                        #    hint_text='bits',
                        #    value='16bit',
                        #    options=[ft.dropdown.Option(x) for x in mm.ComputeType.__args__]  # __args__ is not perfect here. But works.
                        # ),
                        ft.IconButton(
                            icon=ft.icons.START,
                            ref=model_load_unload_button,
                            on_click=lambda _: load_model(),
                        ),
                        ft.ProgressRing(ref=model_loading_spinner, visible=False)
                    ])
                ]),
                ft.Container(expand=True, padding=12, border=ft.border.all(2, 'grey'),
                             alignment=ft.alignment.center,
                             ref=output_text_container,
                             content=ft.Column(
                                 [ft.Text('Nothing to see here!', text_align=ft.TextAlign.CENTER)],
                                 ref=output_text_col,
                                 expand=True,
                                 scroll=ft.ScrollMode.ADAPTIVE)),
            ]), col=8)
        ], expand=True),
    )
 | |
| 
 | |
| 
 | |
# Start the flet application with `main` as its target; `main(page)` builds the UI.
# NOTE(review): runs at import time (no `__main__` guard) — confirm this is
# required by how the app is launched/packaged before adding a guard.
ft.app(target=main)
 |