# tts-anki-dictionary/main.py
#
# 284 lines
# 10 KiB
# Python
# Raw Normal View History
#
# 2024-02-09 01:45:49 +01:00
import os
from typing import Tuple, List
import argparse
import re
from pathlib import Path
from pydub import AudioSegment
import torch
from TTS.api import TTS
import requests
import prompt
OUT_DIR = 'out'  # directory the generated mp3 files are written to
SPEAKER_WAV = 'speaker.wav'  # reference voice recording passed to XTTS
LANG = "de"  # language code passed to XTTS
tts = None  # XTTS model instance, created on first use by get_tts_lazy()

# Perplexity API key. Initialised to None so the name always exists: when the
# key file is missing, later code sees an empty key instead of raising NameError.
api_key = None
try:
    with open('apikey.secret', encoding='utf-8') as key_file:
        api_key = key_file.read().strip()
except FileNotFoundError:
    print('Couldn\'t read API key from file \'apikey.secret\'. Does it exist?')
def is_float(s: str) -> bool:
    """Tell whether ``float()`` accepts *s* without raising ValueError."""
    try:
        float(s)
    except ValueError:
        return False
    return True
def has_word_characters(s: str) -> bool:
    """Tell whether *s* contains at least one regex word character (letter, digit or underscore)."""
    return re.search(r'\w', s) is not None
def transform_string(input_str: str) -> str:
    """
    Turn a phrase into a filename stem, e.g. "Der Gauner, die Gauner" -> "der_gauner_die_gauner".

    The text is lower-cased, spaces become underscores, and the characters
    , ; . / \\ [ ] ( ) are removed.
    """
    # One translation table covers every per-character substitution:
    # ' ' maps to '_', and each punctuation character maps to None (deleted).
    table = str.maketrans({' ': '_', **dict.fromkeys(',;./\\[]()')})
    lowered = input_str.lower()
    return lowered.translate(table)
def get_tts_lazy() -> TTS:
    """Return the module-wide XTTS v2 model, constructing it on the first call only."""
    global tts
    if tts is not None:
        return tts
    # Run on the GPU when PyTorch reports CUDA is available, otherwise on the CPU.
    if torch.cuda.is_available():
        device = "cuda"
    else:
        device = "cpu"
    tts = TTS("tts_models/multilingual/multi-dataset/xtts_v2").to(device)
    return tts
def do_tts(text: str, speaker_wav: str, language: str, mp3_path: str, multisample: int = 4):
    """Generate an audio file for *text* and export it as a 44.1 kHz, 16-bit, stereo, 192k MP3.

    The model sometimes adds long extra phrases, so the text is synthesized
    ``multisample`` times and the shortest take is kept.

    Args:
        text: phrase to speak.
        speaker_wav: reference recording passed to XTTS as ``speaker_wav``.
        language: language code passed to XTTS.
        mp3_path: destination file. Its parent directory is not created here.
        multisample: number of takes to generate. Values below 1 are treated
            as 1. Previously 0 left no take to choose from and crashed.
    """
    import tempfile

    takes = max(1, multisample)
    model = get_tts_lazy()
    # A private temporary directory avoids leaving WAVs behind and keeps
    # concurrent runs from overwriting each other's takes.
    with tempfile.TemporaryDirectory(prefix='tts-') as tmp_dir:
        samples = []
        for i in range(takes):
            wav_path = os.path.join(tmp_dir, f"temp{i}.wav")
            model.tts_to_file(text=text, speaker_wav=speaker_wav, language=language, file_path=wav_path)
            samples.append(AudioSegment.from_file(wav_path))

    # len() of an AudioSegment is its duration. min() keeps the first of equally short takes.
    sound = min(samples, key=len)
    sound = sound.set_frame_rate(44100)  # 44.1 kHz
    sound = sound.set_sample_width(2)  # 2-byte (16-bit) samples
    sound = sound.set_channels(2)  # stereo
    sound.export(mp3_path, format="mp3", bitrate="192k")
# Appears to be a sample of the semicolon-separated phrase list that
# extract_words_from_cards asks the LLM to produce.
# NOTE(review): nothing in the visible code reads this module-level `temp`
# (the `temp` parameter of extract_words_from_cards shadows it). It looks
# like a leftover debug fixture, so confirm before relying on or removing it.
temp = "passieren; der Schaden, die Schäden; der Start, die Starts; die Strecke, die Strecken; der Verkehr; wenden; das Zeichen, die Zeichen; aussteigen; ausweichen; die Autobahn, die Autobahnen; der Bord, die Borde; die Brücke, die Brücken; einholen; einsteigen; entgegenkommen; fort; freigeben; der Hafen, die Häfen; der Halt, die Halte; die Kurve, die Kurven; laden; mobil; der Parkplatz, die Parkplätze; rollen; das Signal, die Signale; sperren; die Station, die Stationen; stoppen; das Tempo, die Tempos; das Ticket, die Tickets; der Transport, die Transporte; transportieren; der Tunnel, die Tunnel; der Unfall, die Unfälle; verkehren; verpassen"
def extract_words_from_cards(cards: List[Tuple[str, str]], temp: float = 0.0,
                             timeout: float = 120.0) -> List[str]:
    """Ask the Perplexity chat API to turn Anki cards into speakable phrases.

    Only the front side (``card[0]``) of each card is sent, joined with ';'.
    The prompt asks the model to reply with semicolon-separated phrases.

    Args:
        cards: (question, answer) pairs. Only the question is used.
        temp: sampling temperature sent to the API.
        timeout: seconds to wait for an HTTP response before treating the
            request as failed and offering a retry.

    Returns:
        The phrases in the order the model returned them. The caller must
        check the count against ``len(cards)``, because the model may not comply.

    Raises:
        RuntimeError: if no API key was loaded.
        SystemExit: (status 0) if a request fails and the user declines to retry.
    """
    if not api_key:
        raise RuntimeError("No Perplexity API key loaded; put it in 'apikey.secret'.")

    url = "https://api.perplexity.ai/chat/completions"
    query_words = ';'.join(card[0] for card in cards)  # german words
    payload = {
        "model": "mixtral-8x7b-instruct",
        # "model": "pplx-70b-chat",
        "messages": [
            {
                "role": "system",
                "content": prompt.CARDS_TO_WORDS_PROMPT,
                # "content": prompt.LLAMA_CARDS_TO_WORDS_PROMPT
            },
            {
                "role": "user",
                "content": "Here are Anki Cards to transform into speakable phrases: \n" + query_words +
                           "\nMake sure to ONLY output a string of semicolon-separated speakable phrases. DO NOT write anything else!"
            }
        ],
        "temperature": temp,
        "presence_penalty": 0
    }
    headers = {
        "accept": "application/json",
        "content-type": "application/json",
        "authorization": f"Bearer {api_key}",
    }

    while True:
        try:
            # Without a timeout, requests can block forever on a stalled connection.
            response = requests.post(url, json=payload, headers=headers, timeout=timeout)
        except requests.exceptions.RequestException as err:
            print(f"Perplexity API error: {err}")
        else:
            if response.status_code == 200:
                break
            print(f"Perplexity API error (HTTP {response.status_code}): {response.text[:500]}")
        user_input = input("Retry? [Y/n] > ")
        if user_input.strip().lower() in ['n', 'no']:
            print("exiting...")
            raise SystemExit(0)

    data = response.json()
    print(data)
    content = data['choices'][0]['message']['content']
    print(f'query: {query_words}')
    print(f'content: {content}')

    # Split at ';'. Strip first, so a phrase that starts on a new line
    # ("a;\nb") is kept. Then keep only the first line of each piece, which
    # drops any trailing commentary the model appends after a newline.
    words = [part.strip().split('\n')[0].strip() for part in content.split(';')]
    # A trailing ';' or closing remark leaves a last piece with no real word in it.
    if words and not has_word_characters(words[-1]):
        words = words[:-1]
    return words
def process_text_to_audio(phrases_and_filenames: List[Tuple[str, str]]):
    """Synthesize each phrase into ``OUT_DIR/<filename>.mp3``.

    Uses the module-level SPEAKER_WAV and LANG settings (previously the same
    values were hard-coded here), so changing those constants affects every
    caller consistently.

    Args:
        phrases_and_filenames: (phrase, filename stem) pairs. ".mp3" is
            appended to the stem.
    """
    os.makedirs(OUT_DIR, exist_ok=True)
    for phrase, filename in phrases_and_filenames:
        mp3_path = os.path.join(OUT_DIR, f"{filename}.mp3")
        do_tts(phrase, SPEAKER_WAV, LANG, mp3_path)
# One Anki card in the note: a "Q: ..." line directly followed by an "A: ..." line.
_CARD_PATTERN = re.compile(r"^Q: (.+)\nA: (.+)\n", re.MULTILINE)


def _build_arg_parser() -> argparse.ArgumentParser:
    """Create the command-line parser for this script."""
    parser = argparse.ArgumentParser(description="Parse markdown note for Anki cards.")
    parser.add_argument('note_file', type=str, help='The path of the markdown note file.')
    parser.add_argument('--out-file', type=str, default='out.md', help='The output file.')
    parser.add_argument('--obsidian-format', action='store_true', help='Use Obsidian path format.')
    parser.add_argument("--batch-size", type=int, default=64,
                        help='Number of cards sent to the LLM at one time, and processed in bulk. Default 64')
    parser.add_argument('--multisample', type=int, default=4,
                        help='Number of audio generations per batch. Reduces audio with arbitrary sounds for short cards. Default 4')
    parser.add_argument('--multisample-multiply-limit', type=int, default=8,
                        help='If a phrase is shorter than the multisample multiply limit, significantly more audio generations (generally *3) are done to improve quality. Set to 0 to disable. Default 8')
    parser.add_argument('--multisample-multiply', type=int, default=3,
                        help='Sets the multiplier for additional audio generations when a phrase is shorter than the multisample multiply limit. Default is 3')
    return parser


def _ask_for_temperature(cur_temp: float) -> float:
    """Prompt until the user enters a number between 0 and 2, and return it."""
    while True:
        new_temp = input(f"cur temp: {cur_temp}. Input new temp (0-2) > ")
        if is_float(new_temp) and 0 <= float(new_temp) <= 2:
            return float(new_temp)
        print("Must be numeric (float) between 0 and 2")


def _generate_phrases(cards: List[Tuple[str, str]], batch_size: int) -> List[str]:
    """Get one LLM phrase per card, asking the user how to retry until the counts match."""
    subbatch_size = batch_size
    cur_temp = 0
    while True:
        batches = [cards[k:k + subbatch_size] for k in range(0, len(cards), subbatch_size)]
        words = []
        for batch in batches:
            words += extract_words_from_cards(batch, temp=cur_temp)
        if len(words) == len(cards):
            return words

        print(f'generated words len ({len(words)}) != matches len ({len(cards)})')
        print(
            f'Current Batch Size is {subbatch_size}, temp is {cur_temp}. If this happens repeatedly, try reducing the batch size.')
        choice = input('Try again? [Y/n/split/temp] > ').strip().lower()
        if choice in ['n', 'no']:
            print("aborting...")
            raise SystemExit(0)
        elif choice in ['s', 'split']:
            # Smaller requests make it easier for the model to return exactly one phrase per card.
            subbatch_size = max(1, subbatch_size // 2)
        elif choice in ['t', 'temp']:
            cur_temp = _ask_for_temperature(cur_temp)
        else:
            print("trying again...")


def _embed_audio_links(content: str, filenames: List[str], obsidian_format: bool) -> str:
    """Append an audio embed to the question line of each card, pairing cards and files in order."""
    remaining = iter(filenames)

    def add_embed(match):
        filename = next(remaining, None)
        if filename is None:  # more cards than generated files: leave the card untouched
            return match.group(0)
        embed = f"![[{filename}]]" if obsidian_format else f"![](./{OUT_DIR}/{filename})"
        return f"Q: {match.group(1)} {embed}\nA: {match.group(2)}\n"

    # Substituting at the matched positions (instead of str.replace on the
    # question text) keeps a question that is a prefix of another question, or
    # that appears twice, from getting the wrong or duplicated embeds.
    return _CARD_PATTERN.sub(add_embed, content)


def main():
    """Generate TTS audio for every Q/A card in a markdown note and write an annotated copy.

    Cards are lines of the form "Q: <question>" followed by "A: <answer>".
    The questions are sent to the LLM in batches to obtain speakable phrases.
    Each phrase is synthesized into OUT_DIR. The note is then written to
    ``--out-file`` with an audio embed appended to each question line.
    """
    parser = _build_arg_parser()
    args = parser.parse_args()
    if args.batch_size < 1:
        parser.error("--batch-size must be at least 1")
    if args.multisample < 1 or args.multisample_multiply < 1:
        parser.error("--multisample and --multisample-multiply must be at least 1")

    note_path = Path(args.note_file)
    if not note_path.exists():
        raise FileNotFoundError(f"File {args.note_file} does not exist.")
    # Explicit UTF-8: German umlauts break under non-UTF-8 platform defaults.
    content = note_path.read_text(encoding='utf-8')
    matches = _CARD_PATTERN.findall(content)

    # The mp3 export below writes into OUT_DIR, so it must exist first.
    os.makedirs(OUT_DIR, exist_ok=True)

    batch_size = args.batch_size
    filenames = []
    for start in range(0, len(matches), batch_size):
        words = _generate_phrases(matches[start:start + batch_size], batch_size)
        print(str(words))
        for offset, word in enumerate(words):
            print(f'speaker-ifying word {word} ({start + offset + 1} of {len(matches)})')
            word = word.strip()
            filename = f"{transform_string(word)}.mp3"
            filenames.append(filename)
            multisample = args.multisample
            if len(word) < args.multisample_multiply_limit:
                # Short phrases are the most prone to spurious extra sounds, so generate more takes.
                multisample *= args.multisample_multiply
            do_tts(word, SPEAKER_WAV, LANG, f'./{OUT_DIR}/{filename}', multisample=multisample)

    out_content = _embed_audio_links(content, filenames, args.obsidian_format)
    with open(args.out_file, 'w', encoding='utf-8') as f:
        f.write(out_content)


if __name__ == '__main__':
    main()