tts-anki-dictionary/main.py

284 lines
10 KiB
Python

import os
from typing import Tuple, List
import argparse
import re
from pathlib import Path
from pydub import AudioSegment
import torch
from TTS.api import TTS
import requests
import prompt
OUT_DIR = 'out'  # directory the generated mp3 files are written to
SPEAKER_WAV = 'speaker.wav'  # speaker reference audio passed to the TTS model
LANG = "de"  # language code passed to the TTS model
tts = None  # TTS model instance, created on first use by get_tts_lazy()

# Perplexity API key. It stays None when the key file is missing, so later code sees a
# defined (empty) value instead of hitting a NameError; requests sent without a valid
# key are expected to be rejected by the API.
api_key = None
try:
    with open('apikey.secret', encoding='utf-8') as f:
        api_key = f.read().strip()
except FileNotFoundError:
    print("Couldn't read API key from file 'apikey.secret'. Does it exist?")
def is_float(s: str) -> bool:
    """Report whether *s* is accepted by the built-in ``float`` constructor."""
    try:
        float(s)
    except ValueError:
        parsable = False
    else:
        parsable = True
    return parsable
def has_word_characters(s: str) -> bool:
    r"""Return True if *s* contains at least one regex word character (``\w``)."""
    return re.search(r'\w', s) is not None
# Characters deleted when a phrase is turned into a file name. Besides ordinary
# punctuation this covers characters that are not allowed in Windows file names
# (< > : " / \ | ? *) and characters that change the meaning of the link the name is
# embedded in: '?' and '#' start a query/fragment in a Markdown URL, and '#' / '|'
# select a heading / alias inside an Obsidian ``![[...]]`` embed.
_FILENAME_DROP_CHARS = ',;./\\[]()?#|:*"<>'
# Single translation table: spaces become underscores, the characters above are removed.
_FILENAME_TABLE = str.maketrans(' ', '_', _FILENAME_DROP_CHARS)


def transform_string(input_str: str) -> str:
    """Turn a phrase into a lowercase, file-name-safe stem.

    Spaces become underscores and punctuation is dropped, e.g.
    ``"Der Gauner, die Gauner"`` -> ``"der_gauner_die_gauner"``.
    Non-ASCII letters such as umlauts are kept.

    Args:
        input_str: the phrase to convert.

    Returns:
        The converted string (without any file extension).
    """
    return input_str.lower().translate(_FILENAME_TABLE)
def get_tts_lazy() -> TTS:
    """Return the shared XTTS v2 model, constructing it on the first call.

    The instance is cached in the module-level ``tts`` variable. It is moved to
    CUDA when ``torch.cuda.is_available()`` reports a GPU, otherwise to the CPU.
    """
    global tts
    if tts is not None:
        return tts
    target = "cuda" if torch.cuda.is_available() else "cpu"
    tts = TTS("tts_models/multilingual/multi-dataset/xtts_v2").to(target)
    return tts
def do_tts(text: str, speaker_wav: str, language: str, mp3_path: str, multisample: int = 4):
    """Generate an mp3 for *text*, keeping the shortest of several takes.

    The model is run ``multisample`` times on the same input and the shortest take
    is kept. Longer takes are more likely to contain extra phrases that were not
    requested. The chosen take is converted to 44.1 kHz, 16-bit stereo and
    exported as a 192 kbit/s mp3.

    Args:
        text: phrase to speak.
        speaker_wav: path of the speaker reference audio passed to the model.
        language: language code passed to the model (e.g. ``"de"``).
        mp3_path: destination file. Its parent directory is not created here.
        multisample: number of takes to generate; values below 1 are treated as 1.
    """
    import tempfile  # only needed here

    # At least one take is needed, otherwise there is nothing to export.
    takes = max(1, multisample)
    model = get_tts_lazy()
    shortest_sound = None
    # A private temp directory is removed automatically afterwards and cannot clash
    # with another run that writes the same temp file names.
    with tempfile.TemporaryDirectory() as tmp_dir:
        for i in range(takes):
            wav_path = os.path.join(tmp_dir, f"temp{i}.wav")
            model.tts_to_file(text=text, speaker_wav=speaker_wav, language=language, file_path=wav_path)
            sound = AudioSegment.from_file(wav_path)
            # Strict '<' keeps the earliest take when durations tie.
            if shortest_sound is None or len(sound) < len(shortest_sound):
                shortest_sound = sound
    sound = shortest_sound
    sound = sound.set_frame_rate(44100)  # Set frame rate to 44.1kHz, high quality
    sound = sound.set_sample_width(2)  # 2 byte (16 bit) samples, high quality
    sound = sound.set_channels(2)  # make it stereo
    # Export as high quality mp3
    sound.export(mp3_path, format="mp3", bitrate="192k")
# Appears to be a sample of the semicolon-separated phrase list the LLM is asked to
# produce (see the user message in extract_words_from_cards). It is not referenced
# anywhere else in the visible code; presumably kept as a debugging fixture.
# NOTE(review): unrelated to the `temp` (temperature) parameter of
# extract_words_from_cards despite the shared name.
temp = "passieren; der Schaden, die Schäden; der Start, die Starts; die Strecke, die Strecken; der Verkehr; wenden; das Zeichen, die Zeichen; aussteigen; ausweichen; die Autobahn, die Autobahnen; der Bord, die Borde; die Brücke, die Brücken; einholen; einsteigen; entgegenkommen; fort; freigeben; der Hafen, die Häfen; der Halt, die Halte; die Kurve, die Kurven; laden; mobil; der Parkplatz, die Parkplätze; rollen; das Signal, die Signale; sperren; die Station, die Stationen; stoppen; das Tempo, die Tempos; das Ticket, die Tickets; der Transport, die Transporte; transportieren; der Tunnel, die Tunnel; der Unfall, die Unfälle; verkehren; verpassen"
def extract_words_from_cards(cards: List[Tuple[str, str]], temp: float = 0.0,
                             timeout: float = 120.0) -> List[str]:
    """Ask the Perplexity chat API to turn Anki cards into speakable phrases.

    Only the question side of each card is sent. The model is told to answer with
    semicolon-separated phrases. The caller must check that the number of phrases
    matches the number of cards.

    Args:
        cards: ``(question, answer)`` pairs; only ``question`` is used.
        temp: sampling temperature forwarded to the API.
        timeout: seconds to wait for each HTTP request before treating it as failed.

    Returns:
        The phrases from the reply. Each phrase is stripped and cut at its first line
        break, which drops trailing comments. Pieces without any word character are
        removed, e.g. the empty piece after a trailing ';'.

    Exits the process (status 0) if the user declines to retry after an API error.
    """
    url = "https://api.perplexity.ai/chat/completions"
    query_words = ';'.join(card[0] for card in cards)  # german words
    payload = {
        "model": "mixtral-8x7b-instruct",
        # alternative model: "pplx-70b-chat"
        "messages": [
            {
                "role": "system",
                "content": prompt.CARDS_TO_WORDS_PROMPT
                # alternative prompt: prompt.LLAMA_CARDS_TO_WORDS_PROMPT
            },
            {
                "role": "user",
                "content": "Here are Anki Cards to transform into speakable phrases: \n" + query_words +
                           "\nMake sure to ONLY output a string of semicolon-separated speakable phrases. DO NOT write anything else!"
            }
        ],
        "temperature": temp,
        "presence_penalty": 0
    }
    headers = {
        "accept": "application/json",
        "content-type": "application/json",
        "authorization": f"Bearer {api_key}",
    }
    while True:
        try:
            # Without a timeout a stalled connection would block forever.
            response = requests.post(url, json=payload, headers=headers, timeout=timeout)
        except requests.RequestException as err:
            # Network failures go through the same retry prompt as HTTP errors.
            print(f"Perplexity API error: {err}")
        else:
            if response.status_code == 200:
                break
            print(f"Perplexity API error (HTTP {response.status_code}): {response.text}")
        user_input = input("Retry? [Y/n] > ")
        if user_input.strip().lower() in ['n', 'no']:
            print("exiting...")
            raise SystemExit(0)
    data = response.json()
    print(data)
    content = data['choices'][0]['message']['content']
    print(f'query: {query_words}')
    print(f'content: {content}')
    # Strip first so a newline right after ';' does not wipe out the phrase, then keep
    # only the first line (the model sometimes appends comments on new lines).
    phrases = (piece.strip().split('\n')[0].strip() for piece in content.split(';'))
    return [phrase for phrase in phrases if has_word_characters(phrase)]
def process_text_to_audio(phrases_and_filenames: List[Tuple[str, str]]):
    """Generate ``OUT_DIR/<filename>.mp3`` for every phrase.

    Args:
        phrases_and_filenames: ``(phrase, filename)`` pairs. ``filename`` is given
            without the ``.mp3`` extension, which is appended here.
    """
    os.makedirs(OUT_DIR, exist_ok=True)
    for phrase, filename in phrases_and_filenames:
        mp3_path = os.path.join(OUT_DIR, f"{filename}.mp3")
        # Use the module-wide settings rather than repeating their literal values.
        do_tts(phrase, SPEAKER_WAV, LANG, mp3_path)
def _parse_args() -> argparse.Namespace:
    # Build the command-line parser and return the validated arguments.
    parser = argparse.ArgumentParser(description="Parse markdown note for Anki cards.")
    parser.add_argument('note_file', type=str, help='The path of the markdown note file.')
    parser.add_argument('--out-file', type=str, default='out.md', help='The output file.')
    parser.add_argument('--obsidian-format', action='store_true', help='Use Obsidian path format.')
    parser.add_argument("--batch-size", type=int, default=64,
                        help='Number of cards sent to the LLM at one time, and processed in bulk. Default 64')
    parser.add_argument('--multisample', type=int, default=4,
                        help='Number of audio generations per phrase. Reduces audio with arbitrary sounds for short cards. Default 4')
    parser.add_argument('--multisample-multiply-limit', type=int, default=8,
                        help='If a phrase is shorter than the multisample multiply limit, significantly more audio generations (generally *3) are done to improve quality. Set to 0 to disable. Default 8')
    parser.add_argument('--multisample-multiply', type=int, default=3,
                        help='Sets the multiplier for additional audio generations when a phrase is shorter than the multisample multiply limit. Default is 3')
    args = parser.parse_args()
    if args.batch_size < 1:
        # A batch size of 0 would make the batching loop divide by / step by zero.
        parser.error('--batch-size must be at least 1')
    return args


def _ask_temperature(cur_temp: float) -> float:
    # Prompt until the user enters a sampling temperature in [0, 2] and return it.
    while True:
        new_temp = input(f"cur temp: {cur_temp}. Input new temp (0-2) > ")
        if is_float(new_temp) and 0 <= float(new_temp) <= 2:
            return float(new_temp)
        print("Must be numeric (float) between 0 and 2")


def _words_for_cards(cards: List[Tuple[str, str]], batch_size: int) -> List[str]:
    # Query the LLM until it returns exactly one phrase per card. Between attempts the
    # user may split the request into smaller sub-batches or change the temperature.
    subbatch_size = batch_size
    cur_temp = 0
    while True:
        words = []
        for start in range(0, len(cards), subbatch_size):
            words += extract_words_from_cards(cards[start:start + subbatch_size], temp=cur_temp)
        if len(words) == len(cards):
            return words
        print(f'generated words len ({len(words)}) != matches len ({len(cards)})')
        print(
            f'Current Batch Size is {subbatch_size}, temp is {cur_temp}. If this happens repeatedly, try reducing the batch size.')
        userinput = input('Try again? [Y/n/split/temp] > ').strip().lower()
        if userinput in ['n', 'no']:
            print("aborting...")
            raise SystemExit(0)
        elif userinput in ['s', 'split']:
            subbatch_size = max(1, subbatch_size // 2)
        elif userinput in ['t', 'temp']:
            cur_temp = _ask_temperature(cur_temp)
        else:
            print("trying again...")


def _embed_audio_link(content: str, question: str, filename: str, obsidian_format: bool) -> str:
    # Append an audio embed to every line that is exactly "Q: <question>".
    if obsidian_format:
        link = f"![[{filename}]]"
    else:
        link = f"![](./{OUT_DIR}/{filename})"
    # Anchoring to whole lines stops a question that is a prefix of another question
    # from also tagging that other line. A function replacement keeps backslashes in
    # the question or filename from being read as regex escapes.
    pattern = rf"^Q: {re.escape(question)}$"
    return re.sub(pattern, lambda m: f"{m.group(0)} {link}", content, flags=re.MULTILINE)


def main():
    """Command-line entry point: add TTS audio links to the Anki cards of a note.

    Reads the markdown note and finds cards written as a ``Q: ...`` line directly
    followed by an ``A: ...`` line. The LLM is asked for one speakable phrase per
    card, in batches of ``--batch-size``. Each phrase is rendered to an mp3 in
    ``OUT_DIR``. A copy of the note with an audio embed appended to each question
    line is written to ``--out-file``.

    Raises:
        FileNotFoundError: if the note file does not exist.
    """
    args = _parse_args()
    note_path = Path(args.note_file)
    if not note_path.exists():
        raise FileNotFoundError(f"File {args.note_file} does not exist.")
    # Read and write as UTF-8 explicitly. The notes contain umlauts, and the platform
    # default encoding (e.g. cp1252 on Windows) would garble them.
    content = note_path.read_text(encoding='utf-8')
    # '$' instead of a literal '\n' also matches a card on the last line of the file.
    matches = re.findall(r"^Q: (.+)\nA: (.+)$", content, re.MULTILINE)
    # do_tts exports straight into OUT_DIR, so it has to exist first.
    os.makedirs(OUT_DIR, exist_ok=True)
    batch_size = args.batch_size
    out_content = content
    for batch_start in range(0, len(matches), batch_size):
        to_match = matches[batch_start:batch_start + batch_size]
        words = _words_for_cards(to_match, batch_size)
        print(str(words))
        filenames = []
        for j, word in enumerate(words):
            print(f'speaker-ifying word {word} ({batch_start + j + 1} of {len(matches)})')
            word = word.strip()
            filename = f"{transform_string(word)}.mp3"
            filenames.append(filename)
            multisample = args.multisample
            if len(word) < args.multisample_multiply_limit:
                multisample *= args.multisample_multiply  # more takes for short words / phrases
            do_tts(word, SPEAKER_WAV, LANG, f'./{OUT_DIR}/{filename}', multisample=multisample)
        # _words_for_cards guarantees one phrase (and so one filename) per card.
        for (question, _answer), filename in zip(to_match, filenames):
            out_content = _embed_audio_link(out_content, question, filename, args.obsidian_format)
    with open(args.out_file, 'w', encoding='utf-8') as f:
        f.write(out_content)
# Run the command-line tool only when this file is executed as a script, not on import.
if __name__ == '__main__':
    main()