Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
14 changes: 13 additions & 1 deletion doc/source/machine_vectors.rst
Original file line number Diff line number Diff line change
Expand Up @@ -80,9 +80,21 @@ Access and conversions

Create vector from distinct entries.

.. function:: vec4n vec4d_convert_limited_vec4n(vec4d a)
.. function:: vec1n vec1d_convert_limited_vec1n(vec1d a)
vec2n vec2d_convert_limited_vec2n(vec2d a)
vec4n vec4d_convert_limited_vec4n(vec4d a)

Given that each entry in the input vector is an exact integer in
``[0, 2^{52})``, convert it to :type:`ulong`.
Note that ``vec1d`` and ``vec2d`` functions are only available on NEON.

.. function:: vec2d vec2n_convert_limited_vec2d(vec2n a)
vec4d vec4n_convert_limited_vec4d(vec4n a)
vec8d vec8n_convert_limited_vec8d(vec8n a)

The inverse conversion, from :type:`ulong` entries to exact
``double`` entries. Requires the same assumption.

Permutations
-------------------------------------------------------------------------------

Expand Down
1 change: 1 addition & 0 deletions src/fft_small.h
Original file line number Diff line number Diff line change
Expand Up @@ -218,6 +218,7 @@ FLINT_FORCE_INLINE ulong sd_fft_ctx_data_size(ulong L)
return n_pow2(L);
}

/* Return the address of block `I` in `d`, where each block has `BLK_SZ` entries. */
FLINT_FORCE_INLINE double* sd_fft_ctx_blk_index(double* d, ulong I)
{
return d + sd_fft_ctx_blk_offset(I);
Expand Down
4 changes: 3 additions & 1 deletion src/fft_small/mpn_helpers.c
Original file line number Diff line number Diff line change
Expand Up @@ -12,7 +12,9 @@
#include "fft_small.h"
#include "crt_helpers.h"

/* transpose a block */
/* For each `0 <= l < np`, reduce values in the `I`-th block of `d + l*dstride`
modulo `Rffts[l].p`, convert the result to integers, and write the result
to `Xs + l*BLK_SZ`. */
void _convert_block(
ulong* Xs,
sd_fft_ctx_struct* Rffts, double* d, ulong dstride,
Expand Down
10 changes: 7 additions & 3 deletions src/fft_small/mpn_mul.c
Original file line number Diff line number Diff line change
Expand Up @@ -40,7 +40,8 @@ The following profiles are hardcoded.
8 192 7 6

Here np is the number of primes {p[0], p[1], ...p[np-1]}. By default the
following 50-bit primes are used:
following 50-bit primes are used (produced by repeatedly applying next_fft_number
on DEFAULT_PRIME):

p[0] = 1108307720798209
p[1] = 659706976665601
Expand Down Expand Up @@ -899,7 +900,7 @@ DEFINE_IT(8, 7, 6)
/*
Specialized helper function, currently only called from mpn_ctx_init.
Assume p is odd and p - 1 has high 2-valuation, return some number q
(not necessarily prime) less than p such that q - 1 has high 2-valuation.
(not necessarily prime) such that q - 1 also has high 2-valuation.
*/
static ulong next_fft_number(ulong p)
{
Expand All @@ -910,9 +911,12 @@ static ulong next_fft_number(ulong p)
if (bits < 15)
flint_throw(FLINT_ERROR, "(%s)\n", __func__);
if (n_nbits(q) == bits)
// Best case: q - 1 has the same bit length and 2-valuation as p - 1
return q;
if (l < 5)
return n_pow2(bits - 2) + 1;
return n_pow2(bits - 2) + 1; // Worst case: drop the bit length by 1
// Second-best case: keep the bit length, but drop the 2-valuation by 1
// (this is the only case where q > p)
return n_pow2(bits) - n_pow2(l - 1) + 1;
}

Expand Down
148 changes: 104 additions & 44 deletions src/fft_small/nmod_poly_mul.c
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,7 @@
#include "thread_pool.h"
#include "thread_support.h"
#include "mpn_extras.h"
#include "ulong_extras.h"
#include "nmod.h"
#include "nmod_vec.h"
#include "nmod_poly.h"
Expand Down Expand Up @@ -317,8 +318,9 @@ static void _crt_1(
prime itself or small enough to be a valid FFT prime. */
FLINT_ASSERT(mod.n <= (UWORD(1) << 50));

/* Todo: generalize _convert_block to allow using the fast path
for all FFT primes. */
/* This applies in the `direct_fft` branch, or when `mod.n` is already
in the context. In the latter case, `s2worker_func` passes
`Rffts = ffts + offset`, making the matched prime `Rffts[0]`. */
have_fft_prime = (mod.n == Rffts[0].mod.n);
if (!have_fft_prime)
{
Expand Down Expand Up @@ -710,6 +712,9 @@ typedef struct {
} s1worker_struct;


/* Reduce `const ulong* b` into `bbuf` and FFT-transform it in place.
Used as a helper for `s1worker_func` in order to process
`bbuf` and `abuf` in parallel, assuming `!squaring`. */
static void extra_func(void* varg)
{
s1worker_struct* X = (s1worker_struct*) varg;
Expand All @@ -719,6 +724,14 @@ static void extra_func(void* varg)
sd_fft_trunc(Q, X->bbuf, X->depth, X->btrunc, X->ztrunc);
}

/* compute convolutions modulo the selected precomputed FFT primes.
Specifically, for each `start_pi <= i < stop_pi`, reduce
`a + stride*i` into `abuf + stride*i`, and multiply it by `b`,
modulo `ffts[i + X->offset].p`. The output is written in-place to
`abuf + stride*i`.
If `squaring` is true, `abuf` is squared, and `bbuf` is not used.
If `flint_request_threads` returns the requested number of threads,
then each `s1worker_func` processes exactly one prime. */
static void s1worker_func(void* varg)
{
s1worker_struct* X = (s1worker_struct*) varg;
Expand All @@ -736,28 +749,22 @@ static void s1worker_func(void* varg)
double* bbuf = X->bbuf;
sd_fft_ctx_struct* Q = X->ffts + ioff;

if (!X->squaring)
if (nworkers > 0)
{
if (nworkers > 0)
{
X->ioff = ioff;
thread_pool_wake(global_thread_pool, handles[0], 0, extra_func, X);
}
else
{
_mod(bbuf, X->btrunc, X->b, X->bn, Q, X->mod);
sd_fft_trunc(Q, bbuf, X->depth, X->btrunc, X->ztrunc);
}
X->ioff = ioff; // read by extra_func on the worker thread
thread_pool_wake(global_thread_pool, handles[0], 0, extra_func, X);
}
else if (!X->squaring)
{
_mod(bbuf, X->btrunc, X->b, X->bn, Q, X->mod);
sd_fft_trunc(Q, bbuf, X->depth, X->btrunc, X->ztrunc);
}

_mod(abuf, X->atrunc, X->a, X->an, Q, X->mod);
sd_fft_trunc(Q, abuf, X->depth, X->atrunc, X->ztrunc);

if (!X->squaring)
{
if (nworkers > 0)
thread_pool_wait(global_thread_pool, handles[0]);
}
if (nworkers > 0)
thread_pool_wait(global_thread_pool, handles[0]);

ulong cop = X->np == 1 ? 1 : *crt_data_co_prime_red(X->crts + X->np - 1, ioff);
NMOD_RED2(m, cop >> (FLINT_BITS - X->depth), cop << X->depth, Q->mod);
Expand Down Expand Up @@ -790,9 +797,11 @@ typedef struct {
ulong* z, ulong zl, ulong zi_start, ulong zi_stop,
sd_fft_ctx_struct* Rffts, double* d, ulong dstride,
crt_data_struct* Rcrts, ulong min_an_bn,
nmod_t mod);
nmod_t mod); /* f = _crt_{np} for 1 <= np <= 4 */
} s2worker_struct;

/* Computes CRT for an output range, writing to `z[zi - zl]` for
`start_zi <= zi < stop_zi`. */
static void s2worker_func(void* varg)
{
s2worker_struct* X = (s2worker_struct*) varg;
Expand All @@ -801,6 +810,38 @@ static void s2worker_func(void* varg)
X->stride, X->crts + X->offset, X->min_an_bn, X->mod);
}

/* Return whether to compute the FFT directly modulo `mod`, rather than
modulo precomputed FFT primes followed by CRT reconstruction.
Direct computation requires `mod.n` to be prime and `mod.n - 1`
to have sufficiently high 2-valuation.
This is usually faster when CRT would need multiple primes, but it pays
the setup cost of initializing and clearing a temporary `sd_fft_ctx_t`. */
static int
_nmod_poly_should_directly_fft(ulong bn, ulong depth, nmod_t mod)
{
if (bn < 1500)
return 0;

if (mod.n <= 2 || mod.n > (UWORD(1) << 50))
return 0;

if (NMOD_BITS(mod) < 20)
return 0;

if (!fft_small_mulmod_satisfies_bounds(mod.n))
/* should be implied by the 50-bit bound, but just in case */
return 0;

if (depth > SD_FFT_CTX_W2TAB_SIZE)
/* unlikely, the convolution length would have to be massive */
return 0;

if (n_trailing_zeros(mod.n - 1) < n_max(depth, SD_FFT_CTX_W2TAB_INIT))
return 0;

return n_is_prime(mod.n); /* check the most expensive condition last */
}

void _nmod_poly_mul_mid_mpn_ctx(
ulong* z, ulong zl, ulong zh,
const ulong* a, ulong an,
Expand All @@ -814,8 +855,11 @@ void _nmod_poly_mul_mid_mpn_ctx(
ulong atrunc, btrunc, ztrunc;
ulong i, np, depth, stride;
double* buf;
sd_fft_ctx_t direct_fft; // initialized with mod.n in direct mode
int squaring;

direct_fft->w2tab[0] = NULL; // use direct_fft branch iff this is not NULL

FLINT_ASSERT(an > 0);
FLINT_ASSERT(bn > 0);

Expand All @@ -839,7 +883,26 @@ void _nmod_poly_mul_mid_mpn_ctx(
FLINT_ASSERT(zl < zh);
FLINT_ASSERT(zh <= zn);

/* first see if mod.n is on of R->ffts[i].mod.n */
atrunc = n_round_up(an, BLK_SZ);
btrunc = n_round_up(bn, BLK_SZ);
ztrunc = n_round_up(zn, BLK_SZ);
/*
if there is a power of two 2^d between zh and zn with good wrap around
i.e. max(an, bn, zh) <= 2^d <= zn with zn - 2^d <= zl
then use d as the depth, otherwise the usual with no wrap around
*/
depth = n_flog2(zn);
i = n_pow2(depth);
if (atrunc <= i && btrunc <= i && zh <= i && i <= zn && zn <= zl + i)
{
ztrunc = i;
}
else
{
depth = n_max(LG_BLK_SZ, n_clog2(ztrunc));
}

/* first see if mod.n is one of R->ffts[i].mod.n */
if (modbits == 50)
{
for (i = 0; i < MPN_CTX_NCRTS; i++)
Expand All @@ -853,6 +916,14 @@ void _nmod_poly_mul_mid_mpn_ctx(
}
}

if (_nmod_poly_should_directly_fft(bn, depth, mod))
{
sd_fft_ctx_init_prime(direct_fft, mod.n);
offset = 0;
np = 1;
goto got_np_and_offset;
}

/* need prod_of_primes >= blen * 4^modbits */
for (np = 1; np < 4; np++)
{
Expand All @@ -870,25 +941,6 @@ void _nmod_poly_mul_mid_mpn_ctx(

got_np_and_offset:

atrunc = n_round_up(an, BLK_SZ);
btrunc = n_round_up(bn, BLK_SZ);
ztrunc = n_round_up(zn, BLK_SZ);
/*
if there is a power of two 2^d between zh and zn with good wrap around
i.e. max(an, bn, zh) <= 2^d <= zn with zn - 2^d <= zl
then use d as the depth, otherwise the usual with no wrap around
*/
depth = n_flog2(zn);
i = n_pow2(depth);
if (atrunc <= i && btrunc <= i && zh <= i && i <= zn && zn <= zl + i)
{
ztrunc = i;
}
else
{
depth = n_max(LG_BLK_SZ, n_clog2(ztrunc));
}

stride = n_round_up(sd_fft_ctx_data_size(depth), 128);

ulong want_threads;
Expand Down Expand Up @@ -923,7 +975,7 @@ void _nmod_poly_mul_mid_mpn_ctx(
X->an = an;
X->b = b;
X->bn = bn;
X->ffts = R->ffts;
X->ffts = direct_fft->w2tab[0] != NULL ? direct_fft : R->ffts;
X->crts = R->crts;
X->mod = mod;
X->squaring = squaring;
Expand Down Expand Up @@ -956,7 +1008,7 @@ void _nmod_poly_mul_mid_mpn_ctx(
X->buf = buf;
X->offset = offset;
X->stride = stride;
X->ffts = R->ffts;
X->ffts = direct_fft->w2tab[0] != NULL ? direct_fft : R->ffts;
X->crts = R->crts;
X->mod = mod;
X->min_an_bn = FLINT_MIN(an, bn);
Expand All @@ -970,6 +1022,9 @@ void _nmod_poly_mul_mid_mpn_ctx(
thread_pool_wait(global_thread_pool, handles[i - 1]);

flint_give_back_threads(handles, nworkers);

if (direct_fft->w2tab[0] != NULL)
sd_fft_ctx_clear(direct_fft);
}

#if 0
Expand Down Expand Up @@ -1003,6 +1058,10 @@ static void _nmod_poly_mul_mod_xpnm1_naive(
}
#endif

/*
Set `z` to the cyclic convolution of `a` and `b` modulo `mod`
with length `N = 2^depth`.
*/
static void _nmod_poly_mul_mod_xpnm1(
ulong* z, ulong ztrunc,
const ulong* a, ulong an,
Expand All @@ -1022,7 +1081,6 @@ static void _nmod_poly_mul_mod_xpnm1(
FLINT_ASSERT(ztrunc <= N);

/* first see if mod.n is one of R->ffts[i].mod.n */

if (modbits == 50)
{
for (i = 0; i < MPN_CTX_NCRTS; i++)
Expand Down Expand Up @@ -1150,6 +1208,9 @@ typedef struct {
nmod_t mod;
} s1pworker_struct;

/* similar to `s1worker_func`, but assume `bbuf` is already FFT-transformed.
If `flint_request_threads` returns the requested number of threads,
then each `s1pworker_func` processes exactly one prime. */
static void s1pworker_func(void* varg)
{
s1pworker_struct* X = (s1pworker_struct*) varg;
Expand Down Expand Up @@ -1189,8 +1250,7 @@ void _mul_precomp_init(

FLINT_ASSERT(bn > 0);

/* first see if mod.n is on of R->ffts[i].mod.n */

/* first see if mod.n is one of R->ffts[i].mod.n */
if (modbits == 50)
{
for (i = 0; i < MPN_CTX_NCRTS; i++)
Expand Down
Loading
Loading