11#include " random_generator.hpp"
22#include " logging.hpp"
3- #include < future>
3+ #include " misc.hpp"
4+ #include < vector>
5+ #include < mutex>
46
57namespace Groth16 {
68
@@ -46,114 +48,84 @@ std::unique_ptr<Prover<Engine>> makeProver(
4648template <typename Engine>
4749std::unique_ptr<Proof<Engine>> Prover<Engine>::prove(typename Engine::FrElement *wtns) {
4850
49- #ifdef USE_OPENMP
51+ ThreadPool &threadPool = ThreadPool::defaultPool ();
52+
5053 LOG_TRACE (" Start Multiexp A" );
5154 uint32_t sW = sizeof (wtns[0 ]);
5255 typename Engine::G1Point pi_a;
53- E.g1 .multiMulByScalar (pi_a, pointsA, (uint8_t *)wtns, sW , nVars);
56+ E.g1 .multiMulByScalarMSM (pi_a, pointsA, (uint8_t *)wtns, sW , nVars);
5457 std::ostringstream ss2;
5558 ss2 << " pi_a: " << E.g1 .toString (pi_a);
5659 LOG_DEBUG (ss2);
5760
5861 LOG_TRACE (" Start Multiexp B1" );
5962 typename Engine::G1Point pib1;
60- E.g1 .multiMulByScalar (pib1, pointsB1, (uint8_t *)wtns, sW , nVars);
63+ E.g1 .multiMulByScalarMSM (pib1, pointsB1, (uint8_t *)wtns, sW , nVars);
6164 std::ostringstream ss3;
6265 ss3 << " pib1: " << E.g1 .toString (pib1);
6366 LOG_DEBUG (ss3);
6467
6568 LOG_TRACE (" Start Multiexp B2" );
6669 typename Engine::G2Point pi_b;
67- E.g2 .multiMulByScalar (pi_b, pointsB2, (uint8_t *)wtns, sW , nVars);
70+ E.g2 .multiMulByScalarMSM (pi_b, pointsB2, (uint8_t *)wtns, sW , nVars);
6871 std::ostringstream ss4;
6972 ss4 << " pi_b: " << E.g2 .toString (pi_b);
7073 LOG_DEBUG (ss4);
7174
7275 LOG_TRACE (" Start Multiexp C" );
7376 typename Engine::G1Point pi_c;
74- E.g1 .multiMulByScalar (pi_c, pointsC, (uint8_t *)((uint64_t )wtns + (nPublic +1 )*sW ), sW , nVars-nPublic-1 );
77+ E.g1 .multiMulByScalarMSM (pi_c, pointsC, (uint8_t *)((uint64_t )wtns + (nPublic +1 )*sW ), sW , nVars-nPublic-1 );
7578 std::ostringstream ss5;
7679 ss5 << " pi_c: " << E.g1 .toString (pi_c);
7780 LOG_DEBUG (ss5);
78- #else
79- LOG_TRACE (" Start Multiexp A" );
80- uint32_t sW = sizeof (wtns[0 ]);
81- typename Engine::G1Point pi_a;
82- auto pA_future = std::async ([&]() {
83- E.g1 .multiMulByScalar (pi_a, pointsA, (uint8_t *)wtns, sW , nVars);
84- });
85-
86- LOG_TRACE (" Start Multiexp B1" );
87- typename Engine::G1Point pib1;
88- auto pB1_future = std::async ([&]() {
89- E.g1 .multiMulByScalar (pib1, pointsB1, (uint8_t *)wtns, sW , nVars);
90- });
91-
92- LOG_TRACE (" Start Multiexp B2" );
93- typename Engine::G2Point pi_b;
94- auto pB2_future = std::async ([&]() {
95- E.g2 .multiMulByScalar (pi_b, pointsB2, (uint8_t *)wtns, sW , nVars);
96- });
97-
98- LOG_TRACE (" Start Multiexp C" );
99- typename Engine::G1Point pi_c;
100- auto pC_future = std::async ([&]() {
101- E.g1 .multiMulByScalar (pi_c, pointsC, (uint8_t *)((uint64_t )wtns + (nPublic +1 )*sW ), sW , nVars-nPublic-1 );
102- });
103- #endif
10481
10582 LOG_TRACE (" Start Initializing a b c A" );
10683 auto a = new typename Engine::FrElement[domainSize];
10784 auto b = new typename Engine::FrElement[domainSize];
10885 auto c = new typename Engine::FrElement[domainSize];
10986
110- #pragma omp parallel for
111- for (u_int32_t i=0 ; i<domainSize; i++) {
112- E.fr .copy (a[i], E.fr .zero ());
113- E.fr .copy (b[i], E.fr .zero ());
114- }
87+ threadPool.parallelFor (0 , domainSize, [&] (int begin, int end, int numThread) {
88+ for (u_int32_t i=begin; i<end; i++) {
89+ E.fr .copy (a[i], E.fr .zero ());
90+ E.fr .copy (b[i], E.fr .zero ());
91+ }
92+ });
11593
11694 LOG_TRACE (" Processing coefs" );
117- #ifdef _OPENMP
118- #define NLOCKS 1024
119- omp_lock_t locks[NLOCKS];
120- for (int i=0 ; i<NLOCKS; i++) omp_init_lock (&locks[i]);
121- #pragma omp parallel for
122- #endif
123- for (u_int64_t i=0 ; i<nCoefs; i++) {
124- typename Engine::FrElement *ab = (coefs[i].m == 0 ) ? a : b;
125- typename Engine::FrElement aux;
126-
127- E.fr .mul (
128- aux,
129- wtns[coefs[i].s ],
130- coefs[i].coef
131- );
132- #ifdef _OPENMP
133- omp_set_lock (&locks[coefs[i].c % NLOCKS]);
134- #endif
135- E.fr .add (
136- ab[coefs[i].c ],
137- ab[coefs[i].c ],
138- aux
139- );
140- #ifdef _OPENMP
141- omp_unset_lock (&locks[coefs[i].c % NLOCKS]);
142- #endif
143- }
144- #ifdef _OPENMP
145- for (int i=0 ; i<NLOCKS; i++) omp_destroy_lock (&locks[i]);
146- #endif
14795
96+ #define NLOCKS 1024
97+ std::vector<std::mutex> locks (NLOCKS);
98+
99+ threadPool.parallelFor (0 , nCoefs, [&] (int begin, int end, int numThread) {
100+ for (u_int64_t i=begin; i<end; i++) {
101+ typename Engine::FrElement *ab = (coefs[i].m == 0 ) ? a : b;
102+ typename Engine::FrElement aux;
103+
104+ E.fr .mul (
105+ aux,
106+ wtns[coefs[i].s ],
107+ coefs[i].coef
108+ );
109+
110+ std::lock_guard<std::mutex> guard (locks[coefs[i].c % NLOCKS]);
111+
112+ E.fr .add (
113+ ab[coefs[i].c ],
114+ ab[coefs[i].c ],
115+ aux
116+ );
117+ }
118+ });
148119 LOG_TRACE (" Calculating c" );
149- #pragma omp parallel for
150- for (u_int32_t i=0 ; i<domainSize; i++) {
151- E.fr .mul (
152- c[i],
153- a[i],
154- b[i]
155- );
156- }
120+ threadPool.parallelFor (0 , domainSize, [&] (int begin, int end, int numThread) {
121+ for (u_int64_t i=begin; i<end; i++) {
122+ E.fr .mul (
123+ c[i],
124+ a[i],
125+ b[i]
126+ );
127+ }
128+ });
157129
158130 LOG_TRACE (" Initializing fft" );
159131 u_int32_t domainPower = fft->log2 (domainSize);
@@ -164,10 +136,13 @@ std::unique_ptr<Proof<Engine>> Prover<Engine>::prove(typename Engine::FrElement
164136 LOG_DEBUG (E.fr .toString (a[0 ]).c_str ());
165137 LOG_DEBUG (E.fr .toString (a[1 ]).c_str ());
166138 LOG_TRACE (" Start Shift A" );
167- #pragma omp parallel for
168- for (u_int64_t i=0 ; i<domainSize; i++) {
169- E.fr .mul (a[i], a[i], fft->root (domainPower+1 , i));
170- }
139+
140+ threadPool.parallelFor (0 , domainSize, [&] (int begin, int end, int numThread) {
141+ for (u_int64_t i=begin; i<end; i++) {
142+ E.fr .mul (a[i], a[i], fft->root (domainPower+1 , i));
143+ }
144+ });
145+
171146 LOG_TRACE (" a After shift:" );
172147 LOG_DEBUG (E.fr .toString (a[0 ]).c_str ());
173148 LOG_DEBUG (E.fr .toString (a[1 ]).c_str ());
@@ -182,10 +157,11 @@ std::unique_ptr<Proof<Engine>> Prover<Engine>::prove(typename Engine::FrElement
182157 LOG_DEBUG (E.fr .toString (b[0 ]).c_str ());
183158 LOG_DEBUG (E.fr .toString (b[1 ]).c_str ());
184159 LOG_TRACE (" Start Shift B" );
185- #pragma omp parallel for
186- for (u_int64_t i=0 ; i<domainSize; i++) {
187- E.fr .mul (b[i], b[i], fft->root (domainPower+1 , i));
188- }
160+ threadPool.parallelFor (0 , domainSize, [&] (int begin, int end, int numThread) {
161+ for (u_int64_t i=begin; i<end; i++) {
162+ E.fr .mul (b[i], b[i], fft->root (domainPower+1 , i));
163+ }
164+ });
189165 LOG_TRACE (" b After shift:" );
190166 LOG_DEBUG (E.fr .toString (b[0 ]).c_str ());
191167 LOG_DEBUG (E.fr .toString (b[1 ]).c_str ());
@@ -201,10 +177,11 @@ std::unique_ptr<Proof<Engine>> Prover<Engine>::prove(typename Engine::FrElement
201177 LOG_DEBUG (E.fr .toString (c[0 ]).c_str ());
202178 LOG_DEBUG (E.fr .toString (c[1 ]).c_str ());
203179 LOG_TRACE (" Start Shift C" );
204- #pragma omp parallel for
205- for (u_int64_t i=0 ; i<domainSize; i++) {
206- E.fr .mul (c[i], c[i], fft->root (domainPower+1 , i));
207- }
180+ threadPool.parallelFor (0 , domainSize, [&] (int begin, int end, int numThread) {
181+ for (u_int64_t i=begin; i<end; i++) {
182+ E.fr .mul (c[i], c[i], fft->root (domainPower+1 , i));
183+ }
184+ });
208185 LOG_TRACE (" c After shift:" );
209186 LOG_DEBUG (E.fr .toString (c[0 ]).c_str ());
210187 LOG_DEBUG (E.fr .toString (c[1 ]).c_str ());
@@ -215,12 +192,13 @@ std::unique_ptr<Proof<Engine>> Prover<Engine>::prove(typename Engine::FrElement
215192 LOG_DEBUG (E.fr .toString (c[1 ]).c_str ());
216193
217194 LOG_TRACE (" Start ABC" );
218- #pragma omp parallel for
219- for (u_int64_t i=0 ; i<domainSize; i++) {
220- E.fr .mul (a[i], a[i], b[i]);
221- E.fr .sub (a[i], a[i], c[i]);
222- E.fr .fromMontgomery (a[i], a[i]);
223- }
195+ threadPool.parallelFor (0 , domainSize, [&] (int begin, int end, int numThread) {
196+ for (u_int64_t i=begin; i<end; i++) {
197+ E.fr .mul (a[i], a[i], b[i]);
198+ E.fr .sub (a[i], a[i], c[i]);
199+ E.fr .fromMontgomery (a[i], a[i]);
200+ }
201+ });
224202 LOG_TRACE (" abc:" );
225203 LOG_DEBUG (E.fr .toString (a[0 ]).c_str ());
226204 LOG_DEBUG (E.fr .toString (a[1 ]).c_str ());
@@ -230,7 +208,7 @@ std::unique_ptr<Proof<Engine>> Prover<Engine>::prove(typename Engine::FrElement
230208
231209 LOG_TRACE (" Start Multiexp H" );
232210 typename Engine::G1Point pih;
233- E.g1 .multiMulByScalar (pih, pointsH, (uint8_t *)a, sizeof (a[0 ]), domainSize);
211+ E.g1 .multiMulByScalarMSM (pih, pointsH, (uint8_t *)a, sizeof (a[0 ]), domainSize);
234212 std::ostringstream ss1;
235213 ss1 << " pih: " << E.g1 .toString (pih);
236214 LOG_DEBUG (ss1);
@@ -247,13 +225,6 @@ std::unique_ptr<Proof<Engine>> Prover<Engine>::prove(typename Engine::FrElement
247225 randombytes_buf ((void *)&(r.v [0 ]), sizeof (r)-1 );
248226 randombytes_buf ((void *)&(s.v [0 ]), sizeof (s)-1 );
249227
250- #ifndef USE_OPENMP
251- pA_future.get ();
252- pB1_future.get ();
253- pB2_future.get ();
254- pC_future.get ();
255- #endif
256-
257228 typename Engine::G1Point p1;
258229 typename Engine::G2Point p2;
259230
0 commit comments