import os
import datetime
import traceback
import threading
import concurrent.futures
import gc
import math
import logging
import numpy as np
import multiprocessing as mp
from typing import Tuple, List, Optional, Dict, Any, Iterable, Union
from pysrt import SubRipTime, SubRipItem, SubRipFile
from sklearn.metrics import log_loss
from copy import deepcopy
from .network import Network
from .embedder import FeatureEmbedder
from .media_helper import MediaHelper
from .singleton import Singleton
from .subtitle import Subtitle
from .hyperparameters import Hyperparameters
from .exception import TerminalException
from .exception import NoFrameRateException
from .logger import Logger
class Predictor(metaclass=Singleton):
    """Predictor for working out the time to shift subtitles."""

    # Hard ceiling (in seconds) on how far subtitles may be shifted.
    __MAX_SHIFT_IN_SECS = 100

    # Upper bound on characters per second of speech.
    # NOTE(review): the original comment derived this from
    # "0.3 word/sec * 6 chars/word", which does not multiply out to 50
    # -- confirm the intended derivation before relying on it.
    __MAX_CHARS_PER_SEC = 50

    # Maximum tolerated head room (feature positions without subtitles)
    # before alignment is deemed suspicious.
    # NOTE(review): the original comment claimed "10 minutes"; at the default
    # 0.04s step this is ~13 minutes -- verify which is intended.
    __MAX_HEAD_ROOM = 20000

    # Maximum waiting time in seconds when predicting each segment.
    __SEGMENT_PREDICTION_TIMEOUT = 60

    # Bound on queued tasks per worker-process thread pool.
    __THREAD_QUEUE_SIZE = 8
    __THREAD_NUMBER = 1  # Do not change
def __init__(self, **kwargs) -> None:
"""Feature predictor initialiser.
Keyword Arguments:
n_mfcc {int} -- The number of MFCC components (default: {13}).
frequency {float} -- The sample rate (default: {16000}).
hop_len {int} -- The number of samples per frame (default: {512}).
step_sample {float} -- The space (in seconds) between the begining of each sample (default: 1s / 25 FPS = 0.04s).
len_sample {float} -- The length in seconds for the input samples (default: {0.075}).
"""
self.__feature_embedder = FeatureEmbedder(**kwargs)
self.__LOGGER = Logger().get_logger(__name__)
self.__media_helper = MediaHelper()
[docs] def predict_single_pass(
self,
video_file_path: str,
subtitle_file_path: str,
weights_dir: str = os.path.join(os.path.dirname(__file__), "models", "training", "weights"),
) -> Tuple[List[SubRipItem], str, Union[np.ndarray, List[float]], Optional[float]]:
"""Predict time to shift with single pass
Arguments:
video_file_path {string} -- The input video file path.
subtitle_file_path {string} -- The path to the subtitle file.
weights_dir {string} -- The the model weights directory.
Returns:
tuple -- The shifted subtitles, the audio file path and the voice probabilities of the original audio.
"""
weights_file_path = self.__get_weights_path(weights_dir)
audio_file_path = ""
frame_rate = None
try:
subs, audio_file_path, voice_probabilities = self.__predict(
video_file_path, subtitle_file_path, weights_file_path
)
try:
frame_rate = self.__media_helper.get_frame_rate(video_file_path)
self.__feature_embedder.step_sample = 1 / frame_rate
self.__on_frame_timecodes(subs)
except NoFrameRateException:
self.__LOGGER.warning("Cannot detect the frame rate for %s" % video_file_path)
return subs, audio_file_path, voice_probabilities, frame_rate
finally:
if os.path.exists(audio_file_path):
os.remove(audio_file_path)
[docs] def predict_dual_pass(
self,
video_file_path: str,
subtitle_file_path: str,
weights_dir: str = os.path.join(os.path.dirname(__file__), "models", "training", "weights"),
stretch: bool = False,
stretch_in_lang: str = "eng",
exit_segfail: bool = False,
) -> Tuple[List[SubRipItem], List[SubRipItem], Union[np.ndarray, List[float]], Optional[float]]:
"""Predict time to shift with single pass
Arguments:
video_file_path {string} -- The input video file path.
subtitle_file_path {string} -- The path to the subtitle file.
weights_dir {string} -- The the model weights directory.
stretch {bool} -- True to stretch the subtitle segments (default: {False})
stretch_in_lang {str} -- The language used for stretching subtitles (default: {"eng"}).
exit_segfail {bool} -- True to exit on any segment alignment failures (default: {False})
Returns:
tuple -- The shifted subtitles, the globally shifted subtitles and the voice probabilities of the original audio.
"""
weights_file_path = self.__get_weights_path(weights_dir)
audio_file_path = ""
frame_rate = None
try:
subs, audio_file_path, voice_probabilities = self.__predict(
video_file_path, subtitle_file_path, weights_file_path
)
new_subs = self.__predict_2nd_pass(
audio_file_path,
subs,
weights_file_path=weights_file_path,
stretch=stretch,
stretch_in_lang=stretch_in_lang,
exit_segfail=exit_segfail,
)
try:
frame_rate = self.__media_helper.get_frame_rate(video_file_path)
self.__feature_embedder.step_sample = 1 / frame_rate
self.__on_frame_timecodes(new_subs)
except NoFrameRateException:
self.__LOGGER.warning("Cannot detect the frame rate for %s" % video_file_path)
self.__LOGGER.debug("Aligned segments generated")
return new_subs, subs, voice_probabilities, frame_rate
finally:
if os.path.exists(audio_file_path):
os.remove(audio_file_path)
[docs] def predict_plain_text(self, video_file_path: str, subtitle_file_path: str, stretch_in_lang: str = "eng") -> Tuple:
"""Predict time to shift with plain texts
Arguments:
video_file_path {string} -- The input video file path.
subtitle_file_path {string} -- The path to the subtitle file.
stretch_in_lang {str} -- The language used for stretching subtitles (default: {"eng"}).
Returns:
tuple -- The shifted subtitles, the audio file path (None) and the voice probabilities of the original audio (None).
"""
from aeneas.executetask import ExecuteTask
from aeneas.task import Task
from aeneas.runtimeconfiguration import RuntimeConfiguration
from aeneas.logger import Logger as AeneasLogger
t = datetime.datetime.now()
audio_file_path = self.__media_helper.extract_audio(
video_file_path, True, 16000
)
self.__LOGGER.debug(
"[{}] Audio extracted after {}".format(
os.getpid(), str(datetime.datetime.now() - t)
)
)
root, _ = os.path.splitext(audio_file_path)
# Initialise a DTW alignment task
task_config_string = (
"task_language={}|os_task_file_format=srt|is_text_type=subtitles".format(stretch_in_lang)
)
runtime_config_string = "dtw_algorithm=stripe" # stripe or exact
task = Task(config_string=task_config_string)
try:
task.audio_file_path_absolute = audio_file_path
task.text_file_path_absolute = subtitle_file_path
task.sync_map_file_path_absolute = "{}.srt".format(root)
tee = False if self.__LOGGER.level == getattr(logging, 'DEBUG') else True
# Execute the task
ExecuteTask(
task=task,
rconf=RuntimeConfiguration(config_string=runtime_config_string),
logger=AeneasLogger(tee=tee),
).execute()
# Output new subtitle segment to a file
task.output_sync_map_file()
# Load the above subtitle segment
adjusted_subs = Subtitle.load(
task.sync_map_file_path_absolute
).subs
frame_rate = None
try:
frame_rate = self.__media_helper.get_frame_rate(video_file_path)
self.__feature_embedder.step_sample = 1 / frame_rate
self.__on_frame_timecodes(adjusted_subs)
except NoFrameRateException:
self.__LOGGER.warning("Cannot detect the frame rate for %s" % video_file_path)
return adjusted_subs, None, None, frame_rate
except KeyboardInterrupt:
raise TerminalException("Subtitle stretch interrupted by the user")
finally:
# Housekeep intermediate files
if task.audio_file_path_absolute is not None and os.path.exists(
task.audio_file_path_absolute
):
os.remove(task.audio_file_path_absolute)
if task.sync_map_file_path_absolute is not None and os.path.exists(task.sync_map_file_path_absolute):
os.remove(task.sync_map_file_path_absolute)
[docs] def get_log_loss(self, voice_probabilities: np.ndarray, subs: List[SubRipItem]) -> float:
"""Returns a single loss value on voice prediction
Arguments:
voice_probabilities {list} -- A list of probabilities of audio chunks being speech.
subs {list} -- A list of subtitle segments.
Returns:
float -- The loss value.
"""
subtitle_mask = Predictor.__get_subtitle_mask(self, subs)
if len(subtitle_mask) == 0:
raise TerminalException("Subtitle is empty")
# Adjust the voice duration when it is shorter than the subtitle duration
# so we can have room to shift the subtitle back and forth based on losses.
head_room = len(voice_probabilities) - len(subtitle_mask)
if head_room < 0:
self.__LOGGER.warning("Audio duration is shorter than the subtitle duration")
local_vp = np.vstack(
[
voice_probabilities,
[np.zeros(voice_probabilities.shape[1])] * (-head_room * 5),
]
)
result = log_loss(
subtitle_mask, local_vp[: len(subtitle_mask)], labels=[0, 1]
)
else:
result = log_loss(
subtitle_mask, voice_probabilities[: len(subtitle_mask)], labels=[0, 1]
)
self.__LOGGER.debug("Log loss: {}".format(result))
return result
[docs] def get_min_log_loss_and_index(self, voice_probabilities: np.ndarray, subs: SubRipFile) -> Tuple[float, int]:
"""Returns the minimum loss value and its shift position after going through all possible shifts.
Arguments:
voice_probabilities {list} -- A list of probabilities of audio chunks being speech.
subs {list} -- A list of subtitle segments.
Returns:
tuple -- The minimum loss value and its position.
"""
local_subs = deepcopy(subs)
local_subs.shift(seconds=-FeatureEmbedder.time_to_sec(subs[0].start))
subtitle_mask = Predictor.__get_subtitle_mask(self, local_subs)
if len(subtitle_mask) == 0:
raise TerminalException("Subtitle is empty")
# Adjust the voice duration when it is shorter than the subtitle duration
# so we can have room to shift the subtitle back and forth based on losses.
head_room = len(voice_probabilities) - len(subtitle_mask)
self.__LOGGER.debug("head room: {}".format(head_room))
if head_room < 0:
local_vp = np.vstack(
[
voice_probabilities,
[np.zeros(voice_probabilities.shape[1])] * (-head_room * 5),
]
)
else:
local_vp = voice_probabilities
head_room = len(local_vp) - len(subtitle_mask)
if head_room > Predictor.__MAX_HEAD_ROOM:
self.__LOGGER.error("head room: {}".format(head_room))
raise TerminalException(
"Maximum head room reached due to the suspicious audio or subtitle duration"
)
log_losses = []
self.__LOGGER.debug(
"Start calculating {} log loss(es)...".format(head_room)
)
for i in np.arange(0, head_room):
log_losses.append(
log_loss(
subtitle_mask,
local_vp[i:i + len(subtitle_mask)],
labels=[0, 1],
)
)
if log_losses:
min_log_loss = min(log_losses)
min_log_loss_idx = log_losses.index(min_log_loss)
else:
min_log_loss = None
min_log_loss_idx = 0
del local_vp
del log_losses
gc.collect()
return min_log_loss, min_log_loss_idx
    @staticmethod
    def _predict_in_multiprocesses(
            self,
            batch_idx: List[int],
            segment_starts: List[str],
            segment_ends: List[str],
            weights_file_path: str,
            audio_file_path: str,
            subs: List[SubRipItem],
            subs_copy: List[SubRipItem],
            stretch: bool,
            stretch_in_lang: str,
            exit_segfail: bool,
    ) -> List[SubRipItem]:
        """Align one batch of segments inside a worker process.

        NOTE(review): declared as a staticmethod that receives ``self``
        explicitly -- presumably so the callable pickles cleanly for
        ProcessPoolExecutor; confirm before changing.
        """
        subs_list = []
        # Bounded thread pool; __THREAD_NUMBER is 1, so segments within this
        # batch are effectively aligned sequentially ("Do not change").
        with _ThreadPoolExecutorLocal(
            queue_size=Predictor.__THREAD_QUEUE_SIZE,
            max_workers=Predictor.__THREAD_NUMBER
        ) as executor:
            lock = threading.RLock()
            # One network instance is shared by all threads in this process.
            network = self.__initialise_network(os.path.dirname(weights_file_path), self.__LOGGER)
            futures = []
            for segment_index in batch_idx:
                futures.append(
                    executor.submit(
                        Predictor._predict_in_multithreads,
                        self,
                        segment_index,
                        segment_starts,
                        segment_ends,
                        weights_file_path,
                        audio_file_path,
                        subs,
                        subs_copy,
                        stretch,
                        stretch_in_lang,
                        exit_segfail,
                        lock,
                        network
                    )
                )
            # Collect results in submission order; on any failure, cancel the
            # remaining futures before surfacing the error.
            for i, future in enumerate(futures):
                try:
                    new_subs = future.result(timeout=Predictor.__SEGMENT_PREDICTION_TIMEOUT)
                except concurrent.futures.TimeoutError as e:
                    self.__cancel_futures(futures[i:], Predictor.__SEGMENT_PREDICTION_TIMEOUT)
                    message = "Segment alignment timed out after {} seconds".format(Predictor.__SEGMENT_PREDICTION_TIMEOUT)
                    self.__LOGGER.error(message)
                    raise TerminalException(message) from e
                except Exception as e:
                    self.__cancel_futures(futures[i:], Predictor.__SEGMENT_PREDICTION_TIMEOUT)
                    message = "Exception on segment alignment: {}\n{}".format(str(e), "".join(traceback.format_stack()))
                    self.__LOGGER.error(e, exc_info=True, stack_info=True)
                    traceback.print_tb(e.__traceback__)
                    # Re-raise TerminalException as-is; wrap anything else.
                    if isinstance(e, TerminalException):
                        raise e
                    else:
                        raise TerminalException(message) from e
                except KeyboardInterrupt:
                    # KeyboardInterrupt is a BaseException, so the Exception
                    # clause above does not swallow it.
                    self.__cancel_futures(futures[i:], Predictor.__SEGMENT_PREDICTION_TIMEOUT)
                    raise TerminalException("Segment alignment interrupted by the user")
                else:
                    self.__LOGGER.debug("Segment aligned")
                    subs_list.extend(new_subs)
        return subs_list
    @staticmethod
    def _predict_in_multithreads(
            self,
            segment_index: int,
            segment_starts: List[str],
            segment_ends: List[str],
            weights_file_path: str,
            audio_file_path: str,
            subs: List[SubRipItem],
            subs_copy: List[SubRipItem],
            stretch: bool,
            stretch_in_lang: str,
            exit_segfail: bool,
            lock: threading.RLock,
            network: Network
    ) -> List[SubRipItem]:
        """Align a single audio/subtitle segment on a worker thread.

        Extracts the segment's audio, re-runs prediction for that segment only
        and optionally stretches the cue durations. On failure the original
        cues for the segment are returned unless ``exit_segfail`` is set.
        """
        segment_path = ""
        try:
            # The last segment runs to the end of the audio (no end timecode).
            if segment_index == (len(segment_starts) - 1):
                segment_path, segment_duration = self.__media_helper.extract_audio_from_start_to_end(
                    audio_file_path, segment_starts[segment_index], None
                )
            else:
                segment_path, segment_duration = self.__media_helper.extract_audio_from_start_to_end(
                    audio_file_path,
                    segment_starts[segment_index],
                    segment_ends[segment_index],
                )
            # Span of this segment's cues, first start to last end, in seconds.
            subtitle_duration = FeatureEmbedder.time_to_sec(
                subs[segment_index][len(subs[segment_index]) - 1].end
            ) - FeatureEmbedder.time_to_sec(subs[segment_index][0].start)
            # Cap the shift so cues cannot overrun the audio segment.
            if segment_duration is None:
                max_shift_secs = None
            else:
                max_shift_secs = segment_duration - subtitle_duration
            # Gap between this segment's first cue and the previous segment's
            # last cue (zero for the first segment).
            if segment_index == 0:
                previous_gap = 0.0
            else:
                previous_gap = FeatureEmbedder.time_to_sec(subs[segment_index][0].start) - FeatureEmbedder.time_to_sec(
                    subs[segment_index - 1][len(subs[segment_index - 1]) - 1].end
                )
            subs_new, _, voice_probabilities = self.__predict(
                video_file_path=None,
                subtitle_file_path=None,
                weights_file_path=weights_file_path,
                audio_file_path=segment_path,
                subtitles=subs_copy[segment_index],
                max_shift_secs=max_shift_secs,
                previous_gap=previous_gap,
                lock=lock,
                network=network
            )
            # The per-segment probabilities are not needed; free them eagerly.
            del voice_probabilities
            gc.collect()
            if stretch:
                subs_new = self.__adjust_durations(subs_new, audio_file_path, stretch_in_lang, lock)
                self.__LOGGER.info("[{}] Segment {} stretched".format(os.getpid(), segment_index))
            return subs_new
        except Exception as e:
            self.__LOGGER.error(
                "[{}] Alignment failed for segment {}: {}\n{}".format(
                    os.getpid(), segment_index, str(e), "".join(traceback.format_stack())
                )
            )
            traceback.print_tb(e.__traceback__)
            if exit_segfail:
                raise TerminalException("At least one of the segments failed on alignment. Exiting...") from e
            # Best-effort: fall back to the original, unaligned cues.
            return subs[segment_index]
        finally:
            # Housekeep intermediate files
            if os.path.exists(segment_path):
                os.remove(segment_path)
@staticmethod
def __minibatch(total: int, batch_size: int) -> Iterable[List[int]]:
batch: List = []
for i in range(total):
if len(batch) == batch_size:
yield batch
batch = []
batch.append(i)
if batch:
yield batch
@staticmethod
def __initialise_network(weights_dir: str, logger: logging.Logger) -> Network:
model_dir = weights_dir.replace("/weights", "/model").replace("\\weights", "\\model")
config_dir = weights_dir.replace("/weights", "/config").replace("\\weights", "\\config")
files = os.listdir(model_dir)
model_files = [
file
for file in files
if file.startswith("model")
]
files = os.listdir(config_dir)
hyperparams_files = [
file
for file in files
if file.startswith("hyperparameters")
]
if not model_files:
raise TerminalException(
"Cannot find model files at {}".format(weights_dir)
)
logger.debug("model files: {}".format(model_files))
logger.debug("config files: {}".format(hyperparams_files))
# Get the first file from the file lists
model_path = os.path.join(model_dir, model_files[0])
hyperparams_path = os.path.join(config_dir, hyperparams_files[0])
hyperparams = Hyperparameters.from_file(hyperparams_path)
return Network.get_from_model(model_path, hyperparams)
@staticmethod
def __get_weights_path(weights_dir: str) -> str:
files = os.listdir(weights_dir)
weights_files = [
file
for file in files
if file.startswith("weights")
]
if not weights_files:
raise TerminalException(
"Cannot find weights files at {}".format(weights_dir)
)
# Get the first file from the file lists
weights_path = os.path.join(weights_dir, weights_files[0])
return os.path.abspath(weights_path)
def __predict_2nd_pass(self, audio_file_path: str, subs: List[SubRipItem], weights_file_path: str, stretch: bool, stretch_in_lang: str, exit_segfail: bool) -> List[SubRipItem]:
"""This function uses divide and conquer to align partial subtitle with partial video.
Arguments:
audio_file_path {string} -- The file path of the original audio.
subs {list} -- A list of SubRip files.
weights_file_path {string} -- The file path of the weights file.
stretch {bool} -- True to stretch the subtitle segments.
stretch_in_lang {str} -- The language used for stretching subtitles.
exit_segfail {bool} -- True to exit on any segment alignment failures.
"""
segment_starts, segment_ends, subs = self.__media_helper.get_audio_segment_starts_and_ends(subs)
subs_copy = deepcopy(subs)
for index, sub in enumerate(subs):
self.__LOGGER.debug(
"Subtitle chunk #{0}: start time: {1} ------> end time: {2}".format(
index, sub[0].start, sub[len(sub) - 1].end
)
)
assert len(segment_starts) == len(
segment_ends
), "Segment start times and end times do not match"
assert len(segment_starts) == len(
subs
), "Segment size and subtitle size do not match"
subs_list = []
max_workers = math.ceil(float(os.getenv("MAX_WORKERS", mp.cpu_count() / 2)))
self.__LOGGER.debug("Number of workers: {}".format(max_workers))
with concurrent.futures.ProcessPoolExecutor(
max_workers=max_workers
) as executor:
batch_size = max(math.floor(len(segment_starts) / max_workers), 1)
futures = [
executor.submit(
Predictor._predict_in_multiprocesses,
self,
batch_idx,
segment_starts,
segment_ends,
weights_file_path,
audio_file_path,
subs,
subs_copy,
stretch,
stretch_in_lang,
exit_segfail
)
for batch_idx in Predictor.__minibatch(len(segment_starts), batch_size)
]
for i, future in enumerate(futures):
try:
subs_list.extend(future.result(timeout=Predictor.__SEGMENT_PREDICTION_TIMEOUT * batch_size))
except concurrent.futures.TimeoutError as e:
self.__cancel_futures(futures[i:], Predictor.__SEGMENT_PREDICTION_TIMEOUT * batch_size)
message = "Batch alignment timed out after {} seconds".format(Predictor.__SEGMENT_PREDICTION_TIMEOUT)
self.__LOGGER.error(message)
raise TerminalException(message) from e
except Exception as e:
self.__cancel_futures(futures[i:], Predictor.__SEGMENT_PREDICTION_TIMEOUT * batch_size)
message = "Exception on batch alignment: {}\n{}".format(str(e), "".join(traceback.format_stack()))
self.__LOGGER.error(e, exc_info=True, stack_info=True)
traceback.print_tb(e.__traceback__)
if isinstance(e, TerminalException):
raise e
else:
raise TerminalException(message) from e
except KeyboardInterrupt:
self.__cancel_futures(futures[i:], Predictor.__SEGMENT_PREDICTION_TIMEOUT * batch_size)
raise TerminalException("Batch alignment interrupted by the user")
else:
self.__LOGGER.debug("Batch aligned")
subs_list = [sub_item for sub_item in subs_list]
self.__LOGGER.debug("All segments aligned")
return subs_list
def __cancel_futures(self, futures: List[concurrent.futures.Future], timeout: int) -> None:
for future in futures:
future.cancel()
concurrent.futures.wait(futures, timeout=timeout)
def __get_subtitle_mask(self, subs: List[SubRipItem]) -> np.ndarray:
pos = self.__feature_embedder.time_to_position(subs[len(subs) - 1].end) - 1
subtitle_mask = np.zeros(pos if pos > 0 else 0)
for sub in subs:
start_pos = self.__feature_embedder.time_to_position(sub.start)
end_pos = self.__feature_embedder.time_to_position(sub.end)
for i in np.arange(start_pos, end_pos):
if i < len(subtitle_mask):
subtitle_mask[i] = 1
return subtitle_mask
def __on_frame_timecodes(self, subs: List[SubRipItem]) -> None:
for sub in subs:
millis_per_frame = self.__feature_embedder.step_sample * 1000
new_start_millis = round(int(str(sub.start).split(",")[1]) / millis_per_frame + 0.5) * millis_per_frame
new_start = str(sub.start).split(",")[0] + "," + str(int(new_start_millis)).zfill(3)
new_end_millis = round(int(str(sub.end).split(",")[1]) / millis_per_frame - 0.5) * millis_per_frame
new_end = str(sub.end).split(",")[0] + "," + str(int(new_end_millis)).zfill(3)
sub.start = SubRipTime.coerce(new_start)
sub.end = SubRipTime.coerce(new_end)
    def __adjust_durations(self, subs: List[SubRipItem], audio_file_path: str, stretch_in_lang: str, lock: threading.RLock) -> List[SubRipItem]:
        """Stretch cue durations using an aeneas DTW forced alignment.

        Returns a new list of cues whose timings follow the forced alignment,
        shifted back to the segment's original offset. ``lock`` serialises the
        audio extraction and task execution across threads.
        """
        # aeneas is imported lazily as an optional heavyweight dependency.
        from aeneas.executetask import ExecuteTask
        from aeneas.task import Task
        from aeneas.runtimeconfiguration import RuntimeConfiguration
        from aeneas.logger import Logger as AeneasLogger

        # Initialise a DTW alignment task
        task_config_string = (
            "task_language={}|os_task_file_format=srt|is_text_type=subtitles".format(stretch_in_lang)
        )
        runtime_config_string = "dtw_algorithm=stripe"  # stripe or exact
        task = Task(config_string=task_config_string)
        try:
            with lock:
                # Cut out just the audio spanned by these cues.
                segment_path, _ = self.__media_helper.extract_audio_from_start_to_end(
                    audio_file_path,
                    str(subs[0].start),
                    str(subs[len(subs) - 1].end),
                )

                # Create a text file for DTW alignments
                root, _ = os.path.splitext(segment_path)
                text_file_path = "{}.txt".format(root)
                with open(text_file_path, "w", encoding="utf8") as text_file:
                    for sub_new in subs:
                        text_file.write(sub_new.text)
                        text_file.write(os.linesep * 2)

                task.audio_file_path_absolute = segment_path
                task.text_file_path_absolute = text_file_path
                task.sync_map_file_path_absolute = "{}.srt".format(root)

                # Mirror aeneas console output only when debugging.
                tee = self.__LOGGER.level == getattr(logging, 'DEBUG')

                # Execute the task
                ExecuteTask(
                    task=task,
                    rconf=RuntimeConfiguration(config_string=runtime_config_string),
                    logger=AeneasLogger(tee=tee),
                ).execute()

                # Output new subtitle segment to a file
                task.output_sync_map_file()

                # Load the above subtitle segment
                adjusted_subs = Subtitle.load(
                    task.sync_map_file_path_absolute
                ).subs

                # Preserve the original cue indices on the reloaded cues.
                for index, sub_new_loaded in enumerate(adjusted_subs):
                    sub_new_loaded.index = subs[index].index

                # Re-apply the segment's start offset lost during extraction.
                adjusted_subs.shift(
                    seconds=self.__media_helper.get_duration_in_seconds(
                        start=None, end=str(subs[0].start)
                    )
                )
                return adjusted_subs
        except KeyboardInterrupt:
            raise TerminalException("Subtitle stretch interrupted by the user")
        finally:
            # Housekeep intermediate files
            if task.audio_file_path_absolute is not None and os.path.exists(
                task.audio_file_path_absolute
            ):
                os.remove(task.audio_file_path_absolute)
            if task.text_file_path_absolute is not None and os.path.exists(
                task.text_file_path_absolute
            ):
                os.remove(task.text_file_path_absolute)
            if task.sync_map_file_path_absolute is not None and os.path.exists(task.sync_map_file_path_absolute):
                os.remove(task.sync_map_file_path_absolute)
def __predict(
self,
video_file_path: Optional[str],
subtitle_file_path: Optional[str],
weights_file_path: str,
audio_file_path: Optional[str] = None,
subtitles: Optional[SubRipFile] = None,
max_shift_secs: Optional[float] = None,
previous_gap: Optional[float] = None,
lock: Optional[threading.RLock] = None,
network: Optional[Network] = None
) -> Tuple[List[SubRipItem], str, np.ndarray]:
"""Shift out-of-sync subtitle cues by sending the audio track of an video to the trained network.
Arguments:
video_file_path {string} -- The file path of the original video.
subtitle_file_path {string} -- The file path of the out-of-sync subtitles.
weights_file_path {string} -- The file path of the weights file.
Keyword Arguments:
audio_file_path {string} -- The file path of the original audio (default: {None}).
subtitles {list} -- The list of SubRip files (default: {None}).
max_shift_secs {float} -- The maximum seconds by which subtitle cues can be shifted (default: {None}).
previous_gap {float} -- The duration between the start time of the audio segment and the start time of the subtitle segment (default: {None}).
Returns:
tuple -- The shifted subtitles, the audio file path and the voice probabilities of the original audio.
"""
if network is None:
network = self.__initialise_network(os.path.dirname(weights_file_path), self.__LOGGER)
result: Dict[str, Any] = {}
pred_start = datetime.datetime.now()
if audio_file_path is not None:
result["audio_file_path"] = audio_file_path
elif video_file_path is not None:
t = datetime.datetime.now()
audio_file_path = self.__media_helper.extract_audio(
video_file_path, True, 16000
)
self.__LOGGER.debug(
"[{}] Audio extracted after {}".format(
os.getpid(), str(datetime.datetime.now() - t)
)
)
result["video_file_path"] = video_file_path
else:
raise TerminalException("Neither audio nor video is passed in")
if subtitle_file_path is not None:
subs = Subtitle.load(subtitle_file_path).subs
result["subtitle_file_path"] = subtitle_file_path
elif subtitles is not None:
subs = subtitles
else:
if os.path.exists(audio_file_path):
os.remove(audio_file_path)
raise TerminalException("ERROR: No subtitles passed in")
if lock is not None:
with lock:
try:
train_data, labels = self.__feature_embedder.extract_data_and_label_from_audio(
audio_file_path, None, subtitles=subs
)
except TerminalException:
if os.path.exists(audio_file_path):
os.remove(audio_file_path)
raise
else:
try:
train_data, labels = self.__feature_embedder.extract_data_and_label_from_audio(
audio_file_path, None, subtitles=subs
)
except TerminalException:
if os.path.exists(audio_file_path):
os.remove(audio_file_path)
raise
train_data = np.array([np.rot90(val) for val in train_data])
train_data = train_data - np.mean(train_data, axis=0)
result["time_load_dataset"] = str(datetime.datetime.now() - pred_start)
result["X_shape"] = train_data.shape
# Load neural network
input_shape = (train_data.shape[1], train_data.shape[2])
self.__LOGGER.debug("[{}] input shape: {}".format(os.getpid(), input_shape))
# Network class is not thread safe so a new graph is created for each thread
pred_start = datetime.datetime.now()
if lock is not None:
with lock:
try:
self.__LOGGER.debug("[{}] Start predicting...".format(os.getpid()))
voice_probabilities = network.get_predictions(train_data, weights_file_path)
except Exception as e:
self.__LOGGER.error("[{}] Prediction failed: {}\n{}".format(os.getpid(), str(e), "".join(traceback.format_stack())))
traceback.print_tb(e.__traceback__)
raise TerminalException("Prediction failed") from e
finally:
del train_data
del labels
gc.collect()
else:
try:
self.__LOGGER.debug("[{}] Start predicting...".format(os.getpid()))
voice_probabilities = network.get_predictions(train_data, weights_file_path)
except Exception as e:
self.__LOGGER.error(
"[{}] Prediction failed: {}\n{}".format(os.getpid(), str(e), "".join(traceback.format_stack())))
traceback.print_tb(e.__traceback__)
raise TerminalException("Prediction failed") from e
finally:
del train_data
del labels
gc.collect()
if len(voice_probabilities) <= 0:
if os.path.exists(audio_file_path):
os.remove(audio_file_path)
raise TerminalException(
"ERROR: Audio is too short and no voice was detected"
)
result["time_predictions"] = str(datetime.datetime.now() - pred_start)
original_start = FeatureEmbedder.time_to_sec(subs[0].start)
shifted_subs = deepcopy(subs)
subs.shift(seconds=-original_start)
self.__LOGGER.info("[{}] Aligning subtitle with video...".format(os.getpid()))
if lock is not None:
with lock:
min_log_loss, min_log_loss_pos = self.get_min_log_loss_and_index(
voice_probabilities, subs
)
else:
min_log_loss, min_log_loss_pos = self.get_min_log_loss_and_index(
voice_probabilities, subs
)
pos_to_delay = min_log_loss_pos
result["loss"] = min_log_loss
self.__LOGGER.info("[{}] Subtitle aligned".format(os.getpid()))
if subtitle_file_path is not None: # for the first pass
seconds_to_shift = (
self.__feature_embedder.position_to_duration(pos_to_delay) - original_start
)
elif subtitles is not None: # for each in second pass
seconds_to_shift = self.__feature_embedder.position_to_duration(pos_to_delay) - previous_gap if previous_gap is not None else 0.0
else:
if os.path.exists(audio_file_path):
os.remove(audio_file_path)
raise ValueError("ERROR: No subtitles passed in")
if abs(seconds_to_shift) > Predictor.__MAX_SHIFT_IN_SECS:
if os.path.exists(audio_file_path):
os.remove(audio_file_path)
raise TerminalException(
"Average shift duration ({} secs) have been reached".format(
Predictor.__MAX_SHIFT_IN_SECS
)
)
result["seconds_to_shift"] = seconds_to_shift
result["original_start"] = original_start
total_elapsed_time = str(datetime.datetime.now() - pred_start)
result["time_sync"] = total_elapsed_time
self.__LOGGER.debug("[{}] Statistics: {}".format(os.getpid(), result))
self.__LOGGER.debug("[{}] Total Time: {}".format(os.getpid(), total_elapsed_time))
self.__LOGGER.debug(
"[{}] Seconds to shift: {}".format(os.getpid(), seconds_to_shift)
)
# For each subtitle chunk, its end time should not be later than the end time of the audio segment
if max_shift_secs is not None and seconds_to_shift <= max_shift_secs:
shifted_subs.shift(seconds=seconds_to_shift)
elif max_shift_secs is not None and seconds_to_shift > max_shift_secs:
self.__LOGGER.warning(
"[{}] Maximum {} seconds shift has reached".format(os.getpid(), max_shift_secs)
)
shifted_subs.shift(seconds=max_shift_secs)
else:
shifted_subs.shift(seconds=seconds_to_shift)
self.__LOGGER.debug("[{}] Subtitle shifted".format(os.getpid()))
return shifted_subs, audio_file_path, voice_probabilities
class _ThreadPoolExecutorLocal:
def __init__(self, queue_size: int, max_workers: int):
self.executor = concurrent.futures.ThreadPoolExecutor(max_workers=max_workers)
self.semaphore = threading.BoundedSemaphore(queue_size + max_workers)
def __enter__(self):
return self
def __exit__(self, exc_type, exc_val, exc_tb):
self.executor.shutdown(True)
def submit(self, fn, *args, **kwargs):
self.semaphore.acquire()
try:
future = self.executor.submit(fn, *args, **kwargs)
except Exception:
self.semaphore.release()
raise
else:
future.add_done_callback(lambda x: self.semaphore.release())
return future