Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
Show all changes
28 commits
Select commit Hold shift + click to select a range
527a410
remove double call
hndgzkn Dec 16, 2022
eda2180
set inner_bounds and worker_support as class attribute
hndgzkn Dec 16, 2022
7eb3d5f
update get_local_coordinate signature to contain seg_bounds and remov…
hndgzkn Dec 16, 2022
b6973fd
update get_global_coordinate signature to contain seg_bounds and rem…
hndgzkn Dec 16, 2022
230c460
update is_contained_coordinate signature to contain seg_bounds and se…
hndgzkn Dec 16, 2022
0a22a0a
update checj_area_contained signature to have seg_bounds, seg_inner_b…
hndgzkn Dec 16, 2022
e081315
update get_touched_overlap_slices signature to have seg_bounds, seg_i…
hndgzkn Dec 16, 2022
0b9dc7d
update get_padding_to_overlap to have seg_bounds, seg_inner_bounds
hndgzkn Dec 16, 2022
602bffc
move unnecessary checks under relevant if condition
hndgzkn Dec 29, 2022
8ac9d61
minor
hndgzkn Dec 29, 2022
206b625
reorganize _select_coordinate
hndgzkn Dec 29, 2022
747b19f
update get_touched_segments, coordinate_update, _select_coordinate si…
hndgzkn Dec 29, 2022
c32654b
check if number of workers is greater than 1
hndgzkn Dec 30, 2022
b507a74
separate worker segmentation and local segmentaiton
hndgzkn Jan 3, 2023
94963a2
remove unused methods from WorkerSegmentation
hndgzkn Jan 3, 2023
385c9e1
simplify workersegmentation constructor and remove unused method
hndgzkn Jan 3, 2023
db45452
simplify localsegmentation constructor
hndgzkn Jan 3, 2023
f652322
create parent Segmentation class and move common methods there
hndgzkn Jan 3, 2023
86229ec
remove unused methods from localsegmentation
hndgzkn Jan 3, 2023
42ced63
remove inner parameter from all bounds methods for local bounds as it…
hndgzkn Jan 3, 2023
404a78f
remove unnecessary find_segment from workersegmentation get_touched_s…
hndgzkn Jan 3, 2023
8a074ee
modify get_seg_support signature to use bounds for WorkerSegmentation
hndgzkn Jan 3, 2023
7d6d8bc
remove get_seg_support for LocalSegmentation
hndgzkn Jan 3, 2023
f18c4cc
move get_global_coordinate to Segmentation
hndgzkn Jan 3, 2023
ebc14be
update get_seg_slice signature to use bounds and move it to Segmentation
hndgzkn Jan 3, 2023
e846f70
move common methods to Segmentation
hndgzkn Jan 4, 2023
ea17972
fix Segmentation tests
hndgzkn Jan 4, 2023
24d414b
fix linting
hndgzkn Jan 6, 2023
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
29 changes: 15 additions & 14 deletions dicodile/update_z/coordinate_descent.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,7 +10,7 @@
from dicodile.utils.csc import _dense_transpose_convolve, reconstruct
from dicodile.utils import check_random_state
from dicodile.utils import debug_flags as flags
from dicodile.utils.segmentation import Segmentation
from dicodile.utils.segmentation import WorkerSegmentation
from dicodile.utils.csc import compute_ztz, compute_ztX
from dicodile.utils.shape_helpers import get_valid_support
from dicodile.utils.order_iterator import get_order_iterator
Expand Down Expand Up @@ -86,7 +86,7 @@ def coordinate_descent(X_i, D, reg, z0=None, DtD=None, n_seg='auto',
if n_seg == 'auto':
n_seg = np.array(valid_support) // (2 * np.array(atom_support) - 1)
n_seg = tuple(np.maximum(1, n_seg))
segments = Segmentation(n_seg, signal_support=valid_support)
segments = WorkerSegmentation(n_seg, signal_support=valid_support)

# Pre-compute constants for maintaining the auxillary variable beta and
# compute the coordinate update values.
Expand Down Expand Up @@ -291,7 +291,8 @@ def _init_beta(X_i, D, reg, z_i=None, constants={}, z_positive=False,
return beta, dz_opt, dE


def _select_coordinate(dz_opt, dE, segments, i_seg, strategy, order=None):
def _select_coordinate(dz_opt, dE, segments, i_seg, strategy, seg_bounds,
order=None):
"""Pick a coordinate to update

Parameters
Expand Down Expand Up @@ -322,21 +323,21 @@ def _select_coordinate(dz_opt, dE, segments, i_seg, strategy, order=None):
if strategy in ['random', 'cyclic-r', 'cyclic']:
k0, *pt0 = next(order)
else:
if strategy in ['greedy', 'gs-r']:
seg_slice = segments.get_seg_slice(i_seg, inner=True)
dz_opt_seg = dz_opt[seg_slice]
i0 = abs(dz_opt_seg).argmax()
seg_slice = segments.get_seg_slice(seg_bounds)

if strategy in ['greedy', 'gs-r']:
d_seg = dz_opt[seg_slice]
elif strategy == 'gs-q':
seg_slice = segments.get_seg_slice(i_seg, inner=True)
dE_seg = dE[seg_slice]
i0 = abs(dE_seg).argmax()
d_seg = dE[seg_slice]

i0 = abs(d_seg).argmax()

# TODO: broken~~~!!!
k0, *pt0 = np.unravel_index(i0, dz_opt_seg.shape)
# k0, *pt0 = tuple(fast_unravel(i0, dz_opt_seg.shape))
pt0 = segments.get_global_coordinate(i_seg, pt0)
k0, *pt0 = np.unravel_index(i0, d_seg.shape)
pt0 = segments.get_global_coordinate(pt0, seg_bounds)

dz = dz_opt[(k0, *pt0)]

return k0, pt0, dz


Expand Down Expand Up @@ -435,7 +436,7 @@ def compute_dE(dz_opt, beta, z_hat, reg):
dz_opt * (z_hat + .5 * dz_opt - beta)
# l1 term
+ reg * (abs(z_hat) - abs(z_hat + dz_opt))
)
)


def _check_convergence(segments, tol, iteration, dz_opt, n_coordinates,
Expand Down
23 changes: 12 additions & 11 deletions dicodile/update_z/dicod.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,7 +13,7 @@
from ..utils.csc import _is_rank1, compute_objective
from ..utils.debugs import main_check_beta
from .coordinate_descent import STRATEGIES
from ..utils.segmentation import Segmentation
from ..utils.segmentation import WorkerSegmentation
from .coordinate_descent import coordinate_descent
from ..utils.mpi import broadcast_array, recv_reduce_sum_array
from ..utils.shape_helpers import get_valid_support, find_grid_size
Expand Down Expand Up @@ -256,14 +256,16 @@ def _send_signal(workers, w_world, atom_support, X, z0=None):
X_info["workers_topology"] = w_world, n_workers // w_world

# compute a segmentation for the image,
workers_segments = Segmentation(n_seg=X_info['workers_topology'],
signal_support=valid_support,
overlap=overlap)
workers_segments = WorkerSegmentation(n_seg=X_info['workers_topology'],
signal_support=valid_support,
overlap=overlap)

# Make sure that each worker has at least a segment of twice the size of
# the dictionary. If this is not the case, the algorithm is not valid as it
# is possible to have interference with workers that are not neighbors.
worker_support = workers_segments.get_seg_support(0, inner=True)
worker_inner_bounds = workers_segments.get_seg_bounds(0, inner=True)
worker_support = workers_segments.get_seg_support(worker_inner_bounds)

msg = ("The size of the support in each worker is smaller than twice the "
"size of the atom support. The algorithm is does not converge in "
"this condition. Reduce the number of cores.\n"
Expand All @@ -279,10 +281,10 @@ def _send_signal(workers, w_world, atom_support, X, z0=None):
X = np.array(X, dtype='d')

for i_seg in range(n_workers):
seg_bounds = workers_segments.get_seg_bounds(i_seg)
if z0 is not None:
worker_slice = workers_segments.get_seg_slice(i_seg)
worker_slice = workers_segments.get_seg_slice(seg_bounds)
_send_array(workers.comm, i_seg, z0[worker_slice])
seg_bounds = workers_segments.get_seg_bounds(i_seg)
X_worker_slice = (Ellipsis,) + tuple([
slice(start, end + size_atom_ax - 1)
for (start, end), size_atom_ax in zip(seg_bounds, atom_support)
Expand Down Expand Up @@ -361,13 +363,12 @@ def recv_z_hat(comm, n_atoms, workers_segments):
inner = not flags.GET_OVERLAP_Z_HAT
z_hat = np.empty((n_atoms, *valid_support), dtype='d')
for i_seg in range(workers_segments.effective_n_seg):
worker_support = workers_segments.get_seg_support(
i_seg, inner=inner)
bounds = workers_segments.get_seg_bounds(i_seg, inner=inner)
worker_support = workers_segments.get_seg_support(bounds)
z_worker = np.zeros((n_atoms,) + worker_support, 'd')
comm.Recv([z_worker.ravel(), MPI.DOUBLE], source=i_seg,
tag=constants.TAG_ROOT + i_seg)
worker_slice = workers_segments.get_seg_slice(
i_seg, inner=inner)
worker_slice = workers_segments.get_seg_slice(bounds)
z_hat[worker_slice] = z_worker

return z_hat
Expand Down
14 changes: 9 additions & 5 deletions dicodile/utils/debugs.py
Original file line number Diff line number Diff line change
Expand Up @@ -26,9 +26,11 @@ def main_check_beta(comm, workers_segments):
sum_beta = np.empty(1, 'd')
value = []
for i_worker in range(workers_segments.effective_n_seg):

pt = workers_segments.get_local_coordinate(i_worker, pt_global)
if workers_segments.is_contained_coordinate(i_worker, pt):
worker_bounds = workers_segments.get_seg_bounds(i_worker)
pt = workers_segments.get_local_coordinate(pt_global,
worker_bounds)
if workers_segments.is_contained_coordinate(pt,
worker_bounds):
comm.Recv([sum_beta, MPI.DOUBLE], source=i_worker,
tag=constants.TAG_ROOT + i_probe)
value.append(sum_beta[0])
Expand All @@ -44,8 +46,10 @@ def worker_check_beta(rank, workers_segments, beta, D_shape):

global_test_points = get_global_test_points(workers_segments)
for i_probe, pt_global in enumerate(global_test_points):
pt = workers_segments.get_local_coordinate(rank, pt_global)
if workers_segments.is_contained_coordinate(rank, pt):
worker_bounds = workers_segments.get_seg_bounds(rank)
pt = workers_segments.get_local_coordinate(pt_global,
worker_bounds)
if workers_segments.is_contained_coordinate(pt, worker_bounds):
beta_slice = (Ellipsis,) + pt
sum_beta = np.array(beta[beta_slice].sum(), dtype='d')

Expand Down
Loading