Skip to content
Snippets Groups Projects
Commit 2625e390 authored by MARMORET Axel's avatar MARMORET Axel
Browse files

Fixing issue #1: code compatible with recent versions of numpy (> 1.25) and...

Fixing issue #1: code compatible with recent versions of numpy (> 1.25) and preparing for compatibility with high versions of tensorly.
parent c6929e71
No related branches found
No related tags found
No related merge requests found
......@@ -495,7 +495,10 @@ def one_ntd_step(tensor, ranks, in_core, in_factors, norm_tensor,
temp = tl.tenalg.multi_mode_dot(core, elemprod, skip=mode)
# this line can be computed with tensor contractions
con_modes = [i for i in range(tl.ndim(tensor)) if i != mode]
UtU = tl.tenalg.contract(temp, con_modes, core, con_modes)
# Depending on the version of tensorly, use different version of the API
UtU = tl.tenalg.contract(temp, con_modes, core, con_modes) # Tensorly v0.6.0
# UtU = tl.tenalg.tensordot(temp, core, con_modes) # Replacing contraction because they had been removed in Tensorly v0.9.0
#UtU = unfold(temp, mode)@tl.transpose(unfold(core, mode))
# UtM
......@@ -503,7 +506,10 @@ def one_ntd_step(tensor, ranks, in_core, in_factors, norm_tensor,
temp = tl.tenalg.multi_mode_dot(tensor, factors, skip=mode, transpose = True)
# again, computable by tensor contractions
#MtU = unfold(temp, mode)@tl.transpose(unfold(core, mode))
MtU = tl.tenalg.contract(temp, con_modes, core, con_modes)
# Depedning on the version of tensorly, use different version of the API
MtU = tl.tenalg.contract(temp, con_modes, core, con_modes) # Tensorly v0.6.0
# MtU = tl.tenalg.tensordot(temp, core, con_modes) # Replacing contraction because they had been removed in Tensorly v0.9.0
UtM = tl.transpose(MtU)
......@@ -529,7 +535,7 @@ def one_ntd_step(tensor, ranks, in_core, in_factors, norm_tensor,
# better implementation: reuse the computation of temp !
# Also reuse elemprod form last update
all_MtX = tl.tenalg.mode_dot(temp, tl.transpose(factors[modes_list[-1]]), modes_list[-1])
all_MtM = tl.copy(elemprod)
all_MtM = elemprod.copy()
all_MtM[modes_list[-1]] = factors[modes_list[-1]].T@factors[modes_list[-1]]
#all_MtM = np.array([fac.T@fac for fac in factors])
......
......@@ -26,7 +26,7 @@ setuptools.setup(
license='BSD',
install_requires=[
'nimfa',
'numpy >= 1.18.0, <1.24', # Starting from version 1.24, the NTD algorithm (and potentially the others) does not work anymore. Realted to issue #1 https://github.com/ax-le/nn-fac/issues/1. Should be fixed in the future.
'numpy >= 1.18.0', # Starting from version 1.24, the NTD algorithm (and potentially the others) does not work anymore. Realted to issue #1 https://github.com/ax-le/nn-fac/issues/1. Should be fixed in the future.
'scipy >= 0.13.0',
'tensorly == 0.6.0',
],
......
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Please register or to comment