Skip to content
Snippets Groups Projects
Commit fdc6923f authored by FARRUGIA Nicolas's avatar FARRUGIA Nicolas
Browse files

Correct a bug with resampling

parent e4923157
Branches
Tags
No related merge requests found
......@@ -5,7 +5,7 @@ import torch
from .model import SoundNetEncoding_conv
from torchaudio.transforms import Resample
# Module-level resampler from 48 kHz down to 22 kHz.
# NOTE(review): the embedding functions in this diff each construct their own
# device-local Resample(48000, 22000) instead of using this instance — this
# module-level object looks unused after the fix; confirm and consider removing.
resampling = Resample(48000, 22000)
def load_model(model_file_path, device=None):
if device is None:
......@@ -19,6 +19,7 @@ def load_model(model_file_path, device=None):
# Set model weights using checkpoint file
model.load_state_dict(modeldict['checkpoint'])
model = model.to(device)
model.sample_rate = 48000 # Input sample rate
model.scene_embedding_size = 1024
model.timestamp_embedding_size = 128
......@@ -27,6 +28,8 @@ def load_model(model_file_path, device=None):
def get_scene_embeddings(x, model):
device = x.device
resampling = Resample(48000, 22000).to(device)
x = resampling(x)
audio_length = x.shape[1]
......@@ -37,7 +40,7 @@ def get_scene_embeddings(x, model):
if audio_length < minimum_length:
device = x.device
x = torch.cat((x, torch.zeros(batch_size, 1,minimum_length - audio_length,1).to(device)), dim=2)
with torch.no_grad():
......@@ -51,7 +54,8 @@ def get_scene_embeddings(x, model):
def get_timestamp_embeddings(x, model):
device = x.device
resampling = Resample(48000, 22000).to(device)
x = resampling(x)
audio_length = x.shape[1]
batch_size = x.shape[0]
......@@ -62,7 +66,7 @@ def get_timestamp_embeddings(x, model):
if audio_length < minimum_length:
batch_size = x.shape[0]
device = x.device
x = torch.cat((x, torch.zeros(batch_size, 1,minimum_length - audio_length,1).to(device)), dim=2)
with torch.no_grad():
......
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Please register or sign in to comment