diff --git a/doc/source/machine_vectors.rst b/doc/source/machine_vectors.rst index ba1f8b4f06..a154f68db3 100644 --- a/doc/source/machine_vectors.rst +++ b/doc/source/machine_vectors.rst @@ -80,9 +80,21 @@ Access and conversions Create vector from distinct entries. -.. function:: vec4n vec4d_convert_limited_vec4n(vec4d a) +.. function:: vec1n vec1d_convert_limited_vec1n(vec1d a) + vec2n vec2d_convert_limited_vec2n(vec2d a) + vec4n vec4d_convert_limited_vec4n(vec4d a) + + Given that each entry in the input vector is an exact integer in + ``[0, 2^{52})``, convert it to :type:`ulong`. + Note that ``vec1d`` and ``vec2d`` functions are only available on NEON. + +.. function:: vec2d vec2n_convert_limited_vec2d(vec2n a) + vec4d vec4n_convert_limited_vec4d(vec4n a) vec8d vec8n_convert_limited_vec8d(vec8n a) + The inverse conversion, from :type:`ulong` entries to exact + ``double`` entries. Requires the same assumption. + Permutations ------------------------------------------------------------------------------- diff --git a/src/fft_small.h b/src/fft_small.h index 3646e9b50d..736f01b830 100644 --- a/src/fft_small.h +++ b/src/fft_small.h @@ -218,6 +218,7 @@ FLINT_FORCE_INLINE ulong sd_fft_ctx_data_size(ulong L) return n_pow2(L); } +/* Return the address of block `I` in `d`, where each block has `BLK_SZ` entries. */ FLINT_FORCE_INLINE double* sd_fft_ctx_blk_index(double* d, ulong I) { return d + sd_fft_ctx_blk_offset(I); diff --git a/src/fft_small/mpn_helpers.c b/src/fft_small/mpn_helpers.c index 74fa6547ac..8b037d8a3c 100644 --- a/src/fft_small/mpn_helpers.c +++ b/src/fft_small/mpn_helpers.c @@ -12,7 +12,9 @@ #include "fft_small.h" #include "crt_helpers.h" -/* transpose a block */ +/* For each `0 <= l < np`, reduce values in the `I`-th block of `d + l*dstride` + modulo `Rffts[l].p`, convert the result to integers, and write the result + to `Xs + l*BLK_SZ`. 
*/ void _convert_block( ulong* Xs, sd_fft_ctx_struct* Rffts, double* d, ulong dstride, diff --git a/src/fft_small/mpn_mul.c b/src/fft_small/mpn_mul.c index 739986fa2e..bca664d324 100644 --- a/src/fft_small/mpn_mul.c +++ b/src/fft_small/mpn_mul.c @@ -40,7 +40,8 @@ The following profiles are hardcoded. 8 192 7 6 Here np is the number of primes {p[0], p[1], ...p[np-1]}. By default the -following 50-bit primes are used: +following 50-bit primes are used (produced by repeatedly applying next_fft_number +on DEFAULT_PRIME): p[0] = 1108307720798209 p[1] = 659706976665601 @@ -899,7 +900,7 @@ DEFINE_IT(8, 7, 6) /* Specialized helper function, currently only called from mpn_ctx_init. Assume p is odd and p - 1 has high 2-valuation, return some number q - (not necessarily prime) less than p such that q - 1 has high 2-valuation. + (not necessarily prime) such that q - 1 also has high 2-valuation. */ static ulong next_fft_number(ulong p) { @@ -910,9 +911,12 @@ static ulong next_fft_number(ulong p) if (bits < 15) flint_throw(FLINT_ERROR, "(%s)\n", __func__); if (n_nbits(q) == bits) + // Best case: q - 1 has the same bit length and 2-valuation as p - 1 return q; if (l < 5) - return n_pow2(bits - 2) + 1; + return n_pow2(bits - 2) + 1; // Worst case: drop the bit length by 1 + // Second-best case: keep the bit length, but drop the 2-valuation by 1 + // (this is the only case where q > p) return n_pow2(bits) - n_pow2(l - 1) + 1; } diff --git a/src/fft_small/nmod_poly_mul.c b/src/fft_small/nmod_poly_mul.c index f8b48d6d4b..0970bc0e72 100644 --- a/src/fft_small/nmod_poly_mul.c +++ b/src/fft_small/nmod_poly_mul.c @@ -12,6 +12,7 @@ #include "thread_pool.h" #include "thread_support.h" #include "mpn_extras.h" +#include "ulong_extras.h" #include "nmod.h" #include "nmod_vec.h" #include "nmod_poly.h" @@ -317,8 +318,9 @@ static void _crt_1( prime itself or small enough to be a valid FFT prime. 
*/ FLINT_ASSERT(mod.n <= (UWORD(1) << 50)); - /* Todo: generalize _convert_block to allow using the fast path - for all FFT primes. */ + /* This applies in the `direct_fft` branch, or when `mod.n` is already + in the context. In the latter case, `s2worker_func` passes + `Rffts = ffts + offset`, making the matched prime `Rffts[0]`. */ have_fft_prime = (mod.n == Rffts[0].mod.n); if (!have_fft_prime) { @@ -710,6 +712,9 @@ typedef struct { } s1worker_struct; +/* Reduce `const ulong* b` into `bbuf` and FFT-transform it in place. + Used as a helper for `s1worker_func` in order to process + `bbuf` and `abuf` in parallel, assuming `!squaring`. */ static void extra_func(void* varg) { s1worker_struct* X = (s1worker_struct*) varg; @@ -719,6 +724,14 @@ static void extra_func(void* varg) sd_fft_trunc(Q, X->bbuf, X->depth, X->btrunc, X->ztrunc); } +/* Compute convolutions modulo the selected precomputed FFT primes. + Specifically, for each `start_pi <= i < stop_pi`, reduce + `a` into `abuf + stride*i`, and multiply it by `b`, + modulo `ffts[i + X->offset].p`. The output is written in-place to + `abuf + stride*i`. + If `squaring` is true, `abuf` is squared, and `bbuf` is not used. + If `flint_request_threads` returns the requested number of threads, + then each `s1worker_func` processes exactly one prime. 
*/ static void s1worker_func(void* varg) { s1worker_struct* X = (s1worker_struct*) varg; @@ -736,28 +749,22 @@ static void s1worker_func(void* varg) double* bbuf = X->bbuf; sd_fft_ctx_struct* Q = X->ffts + ioff; - if (!X->squaring) + if (nworkers > 0) { - if (nworkers > 0) - { - X->ioff = ioff; - thread_pool_wake(global_thread_pool, handles[0], 0, extra_func, X); - } - else - { - _mod(bbuf, X->btrunc, X->b, X->bn, Q, X->mod); - sd_fft_trunc(Q, bbuf, X->depth, X->btrunc, X->ztrunc); - } + X->ioff = ioff; // read by extra_func on the worker thread + thread_pool_wake(global_thread_pool, handles[0], 0, extra_func, X); + } + else if (!X->squaring) + { + _mod(bbuf, X->btrunc, X->b, X->bn, Q, X->mod); + sd_fft_trunc(Q, bbuf, X->depth, X->btrunc, X->ztrunc); } _mod(abuf, X->atrunc, X->a, X->an, Q, X->mod); sd_fft_trunc(Q, abuf, X->depth, X->atrunc, X->ztrunc); - if (!X->squaring) - { - if (nworkers > 0) - thread_pool_wait(global_thread_pool, handles[0]); - } + if (nworkers > 0) + thread_pool_wait(global_thread_pool, handles[0]); ulong cop = X->np == 1 ? 1 : *crt_data_co_prime_red(X->crts + X->np - 1, ioff); NMOD_RED2(m, cop >> (FLINT_BITS - X->depth), cop << X->depth, Q->mod); @@ -790,9 +797,11 @@ typedef struct { ulong* z, ulong zl, ulong zi_start, ulong zi_stop, sd_fft_ctx_struct* Rffts, double* d, ulong dstride, crt_data_struct* Rcrts, ulong min_an_bn, - nmod_t mod); + nmod_t mod); /* f = _crt_{np} for 1 <= np <= 4 */ } s2worker_struct; +/* Computes CRT for an output range, writing to `z[zi - zl]` for + `start_zi <= zi < stop_zi`. */ static void s2worker_func(void* varg) { s2worker_struct* X = (s2worker_struct*) varg; @@ -801,6 +810,38 @@ static void s2worker_func(void* varg) X->stride, X->crts + X->offset, X->min_an_bn, X->mod); } +/* Return whether to compute the FFT directly modulo `mod`, rather than + modulo precomputed FFT primes followed by CRT reconstruction. 
+ Direct computation requires `mod.n` to be prime and `mod.n - 1` + to have sufficiently high 2-valuation. + This is usually faster when CRT would need multiple primes, but it pays + the setup cost of initializing and clearing a temporary `sd_fft_ctx_t`. */ +static int +_nmod_poly_should_directly_fft(ulong bn, ulong depth, nmod_t mod) +{ + if (bn < 1500) + return 0; + + if (mod.n <= 2 || mod.n > (UWORD(1) << 50)) + return 0; + + if (NMOD_BITS(mod) < 20) + return 0; + + if (!fft_small_mulmod_satisfies_bounds(mod.n)) + /* should be implied by the 50-bit bound, but just in case */ + return 0; + + if (depth > SD_FFT_CTX_W2TAB_SIZE) + /* unlikely, the convolution length would have to be massive */ + return 0; + + if (n_trailing_zeros(mod.n - 1) < n_max(depth, SD_FFT_CTX_W2TAB_INIT)) + return 0; + + return n_is_prime(mod.n); /* check the most expensive condition last */ +} + void _nmod_poly_mul_mid_mpn_ctx( ulong* z, ulong zl, ulong zh, const ulong* a, ulong an, @@ -814,8 +855,11 @@ void _nmod_poly_mul_mid_mpn_ctx( ulong atrunc, btrunc, ztrunc; ulong i, np, depth, stride; double* buf; + sd_fft_ctx_t direct_fft; // initialized with mod.n in direct mode int squaring; + direct_fft->w2tab[0] = NULL; // use direct_fft branch iff this is not NULL + FLINT_ASSERT(an > 0); FLINT_ASSERT(bn > 0); @@ -839,7 +883,26 @@ void _nmod_poly_mul_mid_mpn_ctx( FLINT_ASSERT(zl < zh); FLINT_ASSERT(zh <= zn); - /* first see if mod.n is on of R->ffts[i].mod.n */ + atrunc = n_round_up(an, BLK_SZ); + btrunc = n_round_up(bn, BLK_SZ); + ztrunc = n_round_up(zn, BLK_SZ); + /* + if there is a power of two 2^d between zh and zn with good wrap around + i.e. 
max(an, bn, zh) <= 2^d <= zn with zn - 2^d <= zl + then use d as the depth, otherwise the usual with no wrap around + */ + depth = n_flog2(zn); + i = n_pow2(depth); + if (atrunc <= i && btrunc <= i && zh <= i && i <= zn && zn <= zl + i) + { + ztrunc = i; + } + else + { + depth = n_max(LG_BLK_SZ, n_clog2(ztrunc)); + } + + /* first see if mod.n is one of R->ffts[i].mod.n */ if (modbits == 50) { for (i = 0; i < MPN_CTX_NCRTS; i++) @@ -853,6 +916,14 @@ void _nmod_poly_mul_mid_mpn_ctx( } } + if (_nmod_poly_should_directly_fft(bn, depth, mod)) + { + sd_fft_ctx_init_prime(direct_fft, mod.n); + offset = 0; + np = 1; + goto got_np_and_offset; + } + /* need prod_of_primes >= blen * 4^modbits */ for (np = 1; np < 4; np++) { @@ -870,25 +941,6 @@ void _nmod_poly_mul_mid_mpn_ctx( got_np_and_offset: - atrunc = n_round_up(an, BLK_SZ); - btrunc = n_round_up(bn, BLK_SZ); - ztrunc = n_round_up(zn, BLK_SZ); - /* - if there is a power of two 2^d between zh and zn with good wrap around - i.e. max(an, bn, zh) <= 2^d <= zn with zn - 2^d <= zl - then use d as the depth, otherwise the usual with no wrap around - */ - depth = n_flog2(zn); - i = n_pow2(depth); - if (atrunc <= i && btrunc <= i && zh <= i && i <= zn && zn <= zl + i) - { - ztrunc = i; - } - else - { - depth = n_max(LG_BLK_SZ, n_clog2(ztrunc)); - } - stride = n_round_up(sd_fft_ctx_data_size(depth), 128); ulong want_threads; @@ -923,7 +975,7 @@ void _nmod_poly_mul_mid_mpn_ctx( X->an = an; X->b = b; X->bn = bn; - X->ffts = R->ffts; + X->ffts = direct_fft->w2tab[0] != NULL ? direct_fft : R->ffts; X->crts = R->crts; X->mod = mod; X->squaring = squaring; @@ -956,7 +1008,7 @@ void _nmod_poly_mul_mid_mpn_ctx( X->buf = buf; X->offset = offset; X->stride = stride; - X->ffts = R->ffts; + X->ffts = direct_fft->w2tab[0] != NULL ? 
direct_fft : R->ffts; X->crts = R->crts; X->mod = mod; X->min_an_bn = FLINT_MIN(an, bn); @@ -970,6 +1022,9 @@ void _nmod_poly_mul_mid_mpn_ctx( thread_pool_wait(global_thread_pool, handles[i - 1]); flint_give_back_threads(handles, nworkers); + + if (direct_fft->w2tab[0] != NULL) + sd_fft_ctx_clear(direct_fft); } #if 0 @@ -1003,6 +1058,10 @@ static void _nmod_poly_mul_mod_xpnm1_naive( } #endif +/* +Set `z` to the cyclic convolution of `a` and `b` modulo `mod` +with length `N = 2^depth`. +*/ static void _nmod_poly_mul_mod_xpnm1( ulong* z, ulong ztrunc, const ulong* a, ulong an, @@ -1022,7 +1081,6 @@ static void _nmod_poly_mul_mod_xpnm1( FLINT_ASSERT(ztrunc <= N); /* first see if mod.n is one of R->ffts[i].mod.n */ - if (modbits == 50) { for (i = 0; i < MPN_CTX_NCRTS; i++) @@ -1150,6 +1208,9 @@ typedef struct { nmod_t mod; } s1pworker_struct; +/* similar to `s1worker_func`, but assume `bbuf` is already FFT-transformed. + If `flint_request_threads` returns the requested number of threads, + then each `s1pworker_func` processes exactly one prime. */ static void s1pworker_func(void* varg) { s1pworker_struct* X = (s1pworker_struct*) varg; @@ -1189,8 +1250,7 @@ void _mul_precomp_init( FLINT_ASSERT(bn > 0); - /* first see if mod.n is on of R->ffts[i].mod.n */ - + /* first see if mod.n is one of R->ffts[i].mod.n */ if (modbits == 50) { for (i = 0; i < MPN_CTX_NCRTS; i++) diff --git a/src/fft_small/profile/p-nmod_poly_direct_fft.c b/src/fft_small/profile/p-nmod_poly_direct_fft.c new file mode 100644 index 0000000000..003ddd3736 --- /dev/null +++ b/src/fft_small/profile/p-nmod_poly_direct_fft.c @@ -0,0 +1,137 @@ +/* + Benchmark _nmod_poly_mul_mid_mpn_ctx when the modulus itself + can serve as the FFT prime. 
+*/ + +#include "nmod.h" +#include "nmod_poly.h" +#include "fft_small.h" +#include "profiler.h" +#include "ulong_extras.h" + +static ulong +find_fft_prime(ulong bits, ulong depth) +{ + ulong lo, hi, step, k, klo, khi; + + if (bits <= depth || bits > 50) + return 0; + + lo = UWORD(1) << (bits - 1); + hi = UWORD(1) << bits; + step = UWORD(1) << depth; + klo = (lo <= 1) ? 1 : ((lo - 1 + step - 1) >> depth); + khi = (hi - 2) >> depth; + + if ((klo & 1) == 0) + klo++; + + for (k = klo; k <= khi; k += 2) + { + ulong p = (k << depth) + 1; + + if (n_is_prime(p) && fft_small_mulmod_satisfies_bounds(p)) + return p; + } + + return 0; +} + +static double +bench_one(ulong p, ulong n, int threads, flint_rand_t state, mpn_ctx_t R) +{ + nmod_t mod; + ulong * a, * b, * z; + ulong i, reps; + timeit_t timer; + double best; + + nmod_init(&mod, p); + + a = FLINT_ARRAY_ALLOC(n, ulong); + b = FLINT_ARRAY_ALLOC(n, ulong); + z = FLINT_ARRAY_ALLOC(2*n - 1, ulong); + + for (i = 0; i < n; i++) + { + a[i] = n_randint(state, p); + b[i] = n_randint(state, p); + } + + flint_set_num_threads(threads); + + _nmod_poly_mul_mid_mpn_ctx(z, 0, 2*n - 1, a, n, b, n, mod, R); + + reps = 1; + if (n <= 1500) + reps = 8; + else if (n <= 3000) + reps = 5; + else if (n <= 6000) + reps = 3; + + best = 1e100; + for (i = 0; i < 3; i++) + { + double t; + + timeit_start_us(timer); + for (ulong j = 0; j < reps; j++) + _nmod_poly_mul_mid_mpn_ctx(z, 0, 2*n - 1, a, n, b, n, mod, R); + timeit_stop_us(timer); + + t = ((double) timer->wall)/reps; + best = FLINT_MIN(best, t); + } + + flint_free(a); + flint_free(b); + flint_free(z); + + return best; +} + +int main(void) +{ + const ulong ns[] = {1500, 3000, 6000, 12000}; + const int thread_counts[] = {1, 8}; + flint_rand_t state; + mpn_ctx_t R; + + flint_rand_init(state); + mpn_ctx_init(R, UWORD(0x0003f00000000001)); + + flint_printf("threads,bits,p,n,depth,usec\n"); + + for (ulong ti = 0; ti < sizeof(thread_counts)/sizeof(thread_counts[0]); ti++) + { + int threads = 
thread_counts[ti]; + + for (ulong ni = 0; ni < sizeof(ns)/sizeof(ns[0]); ni++) + { + ulong n = ns[ni]; + ulong zn = 2*n - 1; + ulong ztrunc = n_round_up(zn, BLK_SZ); + ulong depth = n_max(LG_BLK_SZ, n_clog2(ztrunc)); + + for (ulong bits = depth + 1; bits <= 50; bits++) + { + ulong p = find_fft_prime(bits, depth); + + if (p != 0) + { + double t = bench_one(p, n, threads, state, R); + flint_printf("%d,%wu,%wu,%wu,%wu,%.3f\n", + threads, bits, p, n, depth, t); + fflush(stdout); + } + } + } + } + + mpn_ctx_clear(R); + flint_rand_clear(state); + flint_cleanup(); + + return 0; +} diff --git a/src/fft_small/test/t-nmod_poly_mul.c b/src/fft_small/test/t-nmod_poly_mul.c index 88112864f9..752322d2a5 100644 --- a/src/fft_small/test/t-nmod_poly_mul.c +++ b/src/fft_small/test/t-nmod_poly_mul.c @@ -23,6 +23,89 @@ TEST_FUNCTION_START(_nmod_poly_mul_mid_mpn_ctx, state) mpn_ctx_init(R, UWORD(0x0003f00000000001)); + /* Check the direct FFT branch. */ + { + ulong * a, * b, * c, * d; + ulong an, bn, zn, zl, zh, sz, i; + + nmod_init(&mod, UWORD(7340033)); /* 7*2^20 + 1 */ + + for (slong reps = 0; reps < 4; reps++) + { + flint_set_num_threads(1 + reps); + + an = 1800 + 37*reps; + bn = 1600 + 19*reps; + zn = an + bn - 1; + zl = reps == 0 ? 0 : 123 + reps; + zh = reps == 1 ? 
zn : zn - 17*reps; + sz = FLINT_MAX(zn, zh); + + a = FLINT_ARRAY_ALLOC(an, ulong); + b = FLINT_ARRAY_ALLOC(bn, ulong); + c = FLINT_ARRAY_ALLOC(sz, ulong); + d = FLINT_ARRAY_ALLOC(sz, ulong); + + for (i = 0; i < an; i++) + a[i] = n_randint(state, mod.n); + + for (i = 0; i < bn; i++) + b[i] = n_randint(state, mod.n); + + flint_mpn_zero(c, sz); + _nmod_poly_mul_KS(c, a, an, b, bn, mod); + _nmod_poly_mul_mid_mpn_ctx(d, zl, zh, a, an, b, bn, mod, R); + + for (i = zl; i < zh; i++) + { + if (c[i] != d[i-zl]) + { + flint_printf("(direct fft) mulmid error at index %wu\n", i); + flint_printf("zl=%wu, zh=%wu, an=%wu, bn=%wu\n", zl, zh, an, bn); + flint_printf("mod: %wu\n", mod.n); + flint_abort(); + } + } + + flint_free(a); + flint_free(b); + flint_free(c); + flint_free(d); + } + + an = 1700; + zn = an + an - 1; + zl = 41; + zh = zn - 29; + sz = zn; + + a = FLINT_ARRAY_ALLOC(an, ulong); + c = FLINT_ARRAY_ALLOC(sz, ulong); + d = FLINT_ARRAY_ALLOC(sz, ulong); + + for (i = 0; i < an; i++) + a[i] = n_randint(state, mod.n); + + flint_mpn_zero(c, sz); + _nmod_poly_mul_KS(c, a, an, a, an, mod); + _nmod_poly_mul_mid_mpn_ctx(d, zl, zh, a, an, a, an, mod, R); + + for (i = zl; i < zh; i++) + { + if (c[i] != d[i-zl]) + { + flint_printf("(direct fft squaring) mulmid error at index %wu\n", i); + flint_printf("zl=%wu, zh=%wu, an=%wu\n", zl, zh, an); + flint_printf("mod: %wu\n", mod.n); + flint_abort(); + } + } + + flint_free(a); + flint_free(c); + flint_free(d); + } + /* (slow) test bug where 3 instead of 4 primes were used */ #if 0 { diff --git a/src/thread_support/thread_support.c b/src/thread_support/thread_support.c index c50f0879b5..86193fbb7f 100644 --- a/src/thread_support/thread_support.c +++ b/src/thread_support/thread_support.c @@ -85,8 +85,10 @@ int flint_restore_thread_affinity(void) return thread_pool_restore_affinity(global_thread_pool); } -/* Takes in the *thread limit* but returns the number of **handles**. That is, - * the maximum return value is `thread_limit - 1`. 
*/ +/* Wrapper around `thread_pool_request` that uses `global_thread_pool` and + dynamically allocates the handle array. + Takes in the *thread limit* but returns the number of **handles**. That is, + the maximum return value is `thread_limit - 1`. */ slong flint_request_threads(thread_pool_handle ** handles, slong thread_limit) { slong num_handles = 0; @@ -113,6 +115,8 @@ slong flint_request_threads(thread_pool_handle ** handles, slong thread_limit) return num_handles; } +/* Wrapper around `thread_pool_give_back` that uses `global_thread_pool` + and frees the memory allocated by `flint_request_threads`. */ void flint_give_back_threads(thread_pool_handle * handles, slong num_handles) { slong i;