284 lines
10 KiB
Python
284 lines
10 KiB
Python
import os
|
|
from typing import Tuple, List
|
|
|
|
import argparse
|
|
import re
|
|
from pathlib import Path
|
|
from pydub import AudioSegment
|
|
import torch
|
|
from TTS.api import TTS
|
|
|
|
import requests
|
|
|
|
import prompt
|
|
|
|
# Directory the generated MP3 files are written into.
OUT_DIR = 'out'
# Reference voice sample handed to the TTS model as `speaker_wav`.
SPEAKER_WAV = 'speaker.wav'
# Language code handed to the TTS model.
LANG = "de"

# Cached TTS model instance; created on first use by get_tts_lazy().
tts = None

try:
    with open('apikey.secret', encoding='utf-8') as f:
        api_key = f.read().strip()
except FileNotFoundError:
    # NOTE: api_key stays undefined in this case, so extract_words_from_cards()
    # will fail with a NameError if it is called without a key.
    print('Couldn\'t read API key from file \'apikey.secret\'. Does it exist?')
|
|
|
|
|
|
def is_float(s: str) -> bool:
    """Return True if ``float(s)`` succeeds, False if it raises ValueError."""
    try:
        float(s)
    except ValueError:
        return False
    else:
        return True
|
|
|
|
|
|
def has_word_characters(s: str) -> bool:
    """Report whether *s* contains at least one regex word character.

    A word character is a letter, a digit or an underscore, with Unicode
    letters such as umlauts included.
    """
    return re.search(r'\w', s) is not None
|
|
|
|
|
|
def transform_string(input_str: str) -> str:
    """
    Turn a phrase into a filename-friendly slug.

    The text is lowercased, spaces become underscores, and the characters
    , ; . / \\ [ ] ( ) are removed.
    Example: "Der Gauner, die Gauner" -> "der_gauner_die_gauner"
    """
    substitutions = {' ': '_'}
    # Map every punctuation character to None so translate() drops it.
    substitutions.update(dict.fromkeys(',;./\\[]()'))
    return input_str.lower().translate(str.maketrans(substitutions))
|
|
|
|
|
|
def get_tts_lazy() -> TTS:
    """Return the shared XTTS v2 model, loading it the first time it is needed.

    The instance is cached in the module-level ``tts`` global so later calls
    reuse it. It is moved to "cuda" when ``torch.cuda.is_available()``
    reports a GPU, and to "cpu" otherwise.
    """
    global tts
    if tts is not None:
        return tts
    target_device = "cpu"
    if torch.cuda.is_available():
        target_device = "cuda"
    tts = TTS("tts_models/multilingual/multi-dataset/xtts_v2").to(target_device)
    return tts
|
|
|
|
|
|
def do_tts(text: str, speaker_wav: str, language: str, mp3_path: str, multisample: int = 4):
    """Generate speech for *text* and export it as a 44.1 kHz, 16-bit stereo MP3.

    The TTS model is run ``multisample`` times and the shortest take is kept.
    This guards against the model occasionally tacking long extra phrases or
    sounds onto the requested text.

    Args:
        text: Phrase to speak.
        speaker_wav: Path of the reference voice sample passed to the model.
        language: Language code passed to the model (e.g. "de").
        mp3_path: Destination path of the exported MP3 (192 kbit/s).
        multisample: Number of generations to choose from; must be >= 1.

    Raises:
        ValueError: If ``multisample`` is smaller than 1 (there would be no
            take to export).
    """
    import tempfile  # local import keeps this block self-contained

    if multisample < 1:
        raise ValueError(f"multisample must be >= 1, got {multisample}")

    shortest_sound = None
    # Scratch WAVs go into a private temporary directory that is deleted
    # afterwards, so parallel runs cannot overwrite each other's takes and
    # no files are left behind.
    with tempfile.TemporaryDirectory(prefix="tts_") as tmp_dir:
        for i in range(multisample):
            wav_path = os.path.join(tmp_dir, f"temp{i}.wav")
            get_tts_lazy().tts_to_file(text=text, speaker_wav=speaker_wav, language=language,
                                       file_path=wav_path)
            sound = AudioSegment.from_file(wav_path)
            # len() of an AudioSegment is its duration; keep the first shortest take.
            if shortest_sound is None or len(sound) < len(shortest_sound):
                shortest_sound = sound

    sound = (shortest_sound
             .set_frame_rate(44100)   # 44.1 kHz frame rate
             .set_sample_width(2)     # 2-byte (16-bit) samples
             .set_channels(2))        # stereo

    sound.export(mp3_path, format="mp3", bitrate="192k")
|
|
|
|
|
|
# Sample semicolon-separated German vocabulary list. It appears to be leftover
# test input; nothing in the visible code reads this module-level name
# (TODO confirm before removing).
temp = "; ".join([
    "passieren",
    "der Schaden, die Schäden",
    "der Start, die Starts",
    "die Strecke, die Strecken",
    "der Verkehr",
    "wenden",
    "das Zeichen, die Zeichen",
    "aussteigen",
    "ausweichen",
    "die Autobahn, die Autobahnen",
    "der Bord, die Borde",
    "die Brücke, die Brücken",
    "einholen",
    "einsteigen",
    "entgegenkommen",
    "fort",
    "freigeben",
    "der Hafen, die Häfen",
    "der Halt, die Halte",
    "die Kurve, die Kurven",
    "laden",
    "mobil",
    "der Parkplatz, die Parkplätze",
    "rollen",
    "das Signal, die Signale",
    "sperren",
    "die Station, die Stationen",
    "stoppen",
    "das Tempo, die Tempos",
    "das Ticket, die Tickets",
    "der Transport, die Transporte",
    "transportieren",
    "der Tunnel, die Tunnel",
    "der Unfall, die Unfälle",
    "verkehren",
    "verpassen",
])
|
|
|
|
|
|
def extract_words_from_cards(cards: List[Tuple[str, str]], temp: float = 0.0,
                             timeout: float = 120.0) -> List[str]:
    """Ask the Perplexity chat API to turn Anki cards into speakable phrases.

    Only the left side (``card[0]``) of each card is sent, joined with ';'.
    The model is asked to reply with a single semicolon-separated string.

    Args:
        cards: (question, answer) pairs; only the question side is used.
        temp: Sampling temperature forwarded to the API.
        timeout: Seconds to wait for the HTTP response before the attempt is
            treated as failed (and the user is offered a retry).

    Returns:
        The phrases parsed from the model's reply. The length is NOT
        guaranteed to equal ``len(cards)``; callers must check it.

    Side effects:
        Prints the raw response, the query and the reply. On HTTP or network
        errors it asks the user whether to retry; answering "n"/"no" ends
        the program via ``SystemExit(0)``.
    """
    url = "https://api.perplexity.ai/chat/completions"

    query_words = ';'.join(card[0] for card in cards)  # german words

    payload = {
        "model": "mixtral-8x7b-instruct",
        # alternative model: "pplx-70b-chat"
        "messages": [
            {
                "role": "system",
                # alternative prompt: prompt.LLAMA_CARDS_TO_WORDS_PROMPT
                "content": prompt.CARDS_TO_WORDS_PROMPT,
            },
            {
                "role": "user",
                "content": "Here are Anki Cards to transform into speakable phrases: \n" + query_words +
                           "\nMake sure to ONLY output a string of semicolon-separated speakable phrases. DO NOT write anything else!",
            },
        ],
        "temperature": temp,
        "presence_penalty": 0,
    }
    headers = {
        "accept": "application/json",
        "content-type": "application/json",
        "authorization": f"Bearer {api_key}",
    }

    while True:
        try:
            # Without a timeout, a stalled connection would block forever.
            response = requests.post(url, json=payload, headers=headers, timeout=timeout)
        except requests.exceptions.RequestException as err:
            # Connection errors / timeouts get the same retry prompt as HTTP errors
            # instead of crashing the whole run.
            print(f"Perplexity API error: {err}")
        else:
            if response.status_code == 200:
                break
            print(f"Perplexity API error (HTTP {response.status_code}): {response.text}")

        user_input = input("Retry? [Y/n] > ")
        if user_input.strip().lower() in ['n', 'no']:
            print("exiting...")
            raise SystemExit(0)

    data = response.json()
    print(data)

    content = data['choices'][0]['message']['content']

    print(f'query: {query_words}')
    print(f'content: {content}')

    # Split at ';' and keep only the first line of each piece (the model
    # sometimes appends comments after a newline). Pieces before the last are
    # always followed by ';', so they are phrases: strip leading whitespace
    # first. Otherwise "a;\nb" would turn "b" into "" because its first line
    # is empty.
    pieces = content.split(";")
    words = [piece.lstrip().split('\n')[0].strip() for piece in pieces[:-1]]
    # Text after a line break in the final piece is treated as commentary.
    words.append(pieces[-1].split('\n')[0].strip())
    # A trailing ';' (or trailing punctuation) leaves a final item with no word characters.
    if not has_word_characters(words[-1]):
        words = words[:-1]
    return words
|
|
|
|
|
|
def process_text_to_audio(phrases_and_filenames: List[Tuple[str, str]], multisample: int = 4):
    """Synthesize every phrase into ``OUT_DIR/<filename>.mp3``.

    Args:
        phrases_and_filenames: (phrase, filename) pairs; ``filename`` is given
            without the ``.mp3`` extension.
        multisample: Number of TTS takes per phrase, forwarded to do_tts().
    """
    os.makedirs(OUT_DIR, exist_ok=True)
    for phrase, filename in phrases_and_filenames:
        mp3_path = os.path.join(OUT_DIR, f"{filename}.mp3")
        # Use the module-wide voice/language settings instead of re-hardcoding them.
        do_tts(phrase, SPEAKER_WAV, LANG, mp3_path, multisample=multisample)
|
|
|
|
|
|
def _parse_args() -> argparse.Namespace:
    """Build the CLI parser, parse the command line and validate numeric options."""
    parser = argparse.ArgumentParser(description="Parse markdown note for Anki cards.")
    parser.add_argument('note_file', type=str, help='The path of the markdown note file.')
    parser.add_argument('--out-file', type=str, default='out.md', help='The output file.')
    parser.add_argument('--obsidian-format', action='store_true', help='Use Obsidian path format.')
    parser.add_argument("--batch-size", type=int, default=64,
                        help='Number of cards sent to the LLM at one time, and processed in bulk. Default 64')
    parser.add_argument('--multisample', type=int, default=4,
                        help='Number of audio generations per phrase. Reduces audio with arbitrary sounds for short cards. Default 4')
    parser.add_argument('--multisample-multiply-limit', type=int, default=8,
                        help='If a phrase is shorter than the multisample multiply limit, significantly more audio generations (generally *3) are done to improve quality. Set to 0 to disable. Default 8')
    parser.add_argument('--multisample-multiply', type=int, default=3,
                        help='Sets the multiplier for additional audio generations when a phrase is shorter than the multisample multiply limit. Default is 3')
    args = parser.parse_args()

    # Non-positive values would crash much later (zero-step range() when
    # batching, or zero TTS takes); reject them up front with a usage error.
    for option, value in (('--batch-size', args.batch_size),
                          ('--multisample', args.multisample),
                          ('--multisample-multiply', args.multisample_multiply)):
        if value < 1:
            parser.error(f'{option} must be at least 1, got {value}')
    return args


def _ask_new_temperature(cur_temp: float) -> float:
    """Prompt until the user enters a temperature in [0, 2], then return it."""
    while True:
        new_temp = input(f"cur temp: {cur_temp}. Input new temp (0-2) > ")
        if is_float(new_temp) and 0 <= float(new_temp) <= 2:
            return float(new_temp)
        print("Must be numeric (float) between 0 and 2")


def _generate_phrases(cards: List[Tuple[str, str]], batch_size: int) -> List[str]:
    """Get exactly one speakable phrase per card from the LLM.

    The cards are sent in sub-batches. When the number of returned phrases
    does not match, the user can retry, halve the sub-batch size ("split"),
    change the temperature ("temp") or abort ("n", exits with status 0).
    """
    subbatch_size = batch_size
    cur_temp = 0
    while True:
        words: List[str] = []
        for start in range(0, len(cards), subbatch_size):
            words += extract_words_from_cards(cards[start:start + subbatch_size], temp=cur_temp)

        if len(words) == len(cards):
            return words

        print(f'generated words len ({len(words)}) != matches len ({len(cards)})')
        print(
            f'Current Batch Size is {subbatch_size}, temp is {cur_temp}. If this happens repeatedly, try reducing the batch size.')
        choice = input('Try again? [Y/n/split/temp] > ').strip().lower()
        if choice in ['n', 'no']:
            print("aborting...")
            raise SystemExit(0)
        if choice in ['s', 'split']:
            subbatch_size = max(1, subbatch_size // 2)
        elif choice in ['t', 'temp']:
            cur_temp = _ask_new_temperature(cur_temp)
        else:
            print("trying again...")


def _audio_embed(filename: str, obsidian_format: bool) -> str:
    """Return the markdown snippet that embeds the audio file ``filename``."""
    if obsidian_format:
        return f"![[{filename}]]"
    return f"![](./{OUT_DIR}/{filename})"


def main():
    """CLI entry point: attach generated pronunciation audio to each Q/A card.

    1. Reads ``Q: ...`` / ``A: ...`` line pairs from the markdown note.
    2. Asks the LLM, in batches, for one speakable phrase per card.
    3. Writes one MP3 per phrase into OUT_DIR.
    4. Writes a copy of the note to ``--out-file`` with an audio embed
       appended to every ``Q:`` line.
    """
    args = _parse_args()

    note_path = Path(args.note_file)
    if not note_path.exists():
        raise FileNotFoundError(f"File {args.note_file} does not exist.")
    # Explicit UTF-8 so German umlauts survive regardless of the platform default encoding.
    content = note_path.read_text(encoding='utf-8')

    # `$` (MULTILINE) instead of a literal trailing "\n" so a card on the very
    # last line of a note without a final newline is found as well.
    card_pattern = re.compile(r"^Q: (.+)\nA: (.+)$", re.MULTILINE)
    matches = card_pattern.findall(content)

    # do_tts exports straight into OUT_DIR, so the directory must exist.
    os.makedirs(OUT_DIR, exist_ok=True)

    batch_size = args.batch_size
    filenames: List[str] = []  # one entry per card, in document order
    for start in range(0, len(matches), batch_size):
        to_match = matches[start:start + batch_size]
        words = _generate_phrases(to_match, batch_size)
        print(str(words))

        for offset, word in enumerate(words):
            print(f'speaker-ifying word {word} ({start + offset + 1} of {len(matches)})')
            word = word.strip()
            filename = f"{transform_string(word)}.mp3"
            filenames.append(filename)

            multisample = args.multisample
            if len(word) < args.multisample_multiply_limit:
                # short words / phrases get extra takes to pick the cleanest from
                multisample *= args.multisample_multiply
            do_tts(word, SPEAKER_WAV, LANG, os.path.join(OUT_DIR, filename), multisample=multisample)

    # Annotate each card exactly once, in document order. A plain
    # str.replace("Q: <question>") would also touch other questions that
    # merely start with the same text, and would add several embeds to
    # duplicate questions. The same compiled pattern yields the same matches
    # as findall(), so there is exactly one filename per substitution.
    embeds = iter(filenames)

    def _add_embed(match: re.Match) -> str:
        embed = _audio_embed(next(embeds), args.obsidian_format)
        return f"Q: {match.group(1)} {embed}\nA: {match.group(2)}"

    out_content = card_pattern.sub(_add_embed, content)

    with open(args.out_file, 'w', encoding='utf-8') as f:
        f.write(out_content)


if __name__ == '__main__':
    main()
|