-
Notifications
You must be signed in to change notification settings - Fork 224
Expand file tree
/
Copy pathsvm_model.h
More file actions
242 lines (218 loc) · 8.71 KB
/
svm_model.h
File metadata and controls
242 lines (218 loc) · 8.71 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
/* file: svm_model.h */
/*******************************************************************************
* Copyright 2014 Intel Corporation
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*******************************************************************************/
/*
//++
// Implementation of the class defining the SVM model.
//--
*/
#ifndef __SVM_MODEL_H__
#define __SVM_MODEL_H__
#include "data_management/data/homogen_numeric_table.h"
#include "data_management/data/csr_numeric_table.h"
#include "algorithms/model.h"
#include "algorithms/kernel_function/kernel_function.h"
#include "algorithms/classifier/classifier_model.h"
namespace daal
{
namespace algorithms
{
/**
* @defgroup svm Support Vector Machine Classifier
* \copydoc daal::algorithms::svm
* @ingroup classification
* @{
*/
/**
* \brief Contains classes to work with the support vector machine classifier
*/
namespace svm
{
/**
* \brief Contains version 2.0 of oneAPI Data Analytics Library interface.
*/
namespace interface2
{
/**
* @ingroup svm
* @{
*/
/**
* <a name="DAAL-STRUCT-ALGORITHMS__SVM__PARAMETER"></a>
* \brief Optional parameters
*
* \snippet svm/svm_model.h Parameter source code
*/
/* [Parameter source code] */
struct DAAL_EXPORT Parameter : public classifier::Parameter
{
    /**
     * Builds the SVM parameter set; every argument is copied into the member of the matching role.
     * \param[in] kernelForParameter Kernel function; by default the result of
     *                               kernel_function::createKernelFunction(kernel_function::linearKernel)
     * \param[in] upperBound         Stored in C
     * \param[in] accuracy           Stored in accuracyThreshold
     * \param[in] tauValue           Stored in tau
     * \param[in] iterationLimit     Stored in maxIterations
     * \param[in] cacheBytes         Stored in cacheSize
     * \param[in] useShrinking       Stored in doShrinking
     * \param[in] shrinkingPeriod    Stored in shrinkingStep
     */
    Parameter(const kernel_function::KernelIfacePtr & kernelForParameter = kernel_function::createKernelFunction(kernel_function::linearKernel),
              double upperBound = 1.0, double accuracy = 0.001, double tauValue = 1.0e-6, size_t iterationLimit = 1000000,
              size_t cacheBytes = 8000000, bool useShrinking = true, size_t shrinkingPeriod = 1000)
        : C(upperBound),
          accuracyThreshold(accuracy),
          tau(tauValue),
          maxIterations(iterationLimit),
          cacheSize(cacheBytes),
          doShrinking(useShrinking),
          shrinkingStep(shrinkingPeriod),
          kernel(kernelForParameter)
    {}

    double C;                 /*!< Upper bound in the constraints of the quadratic optimization problem */
    double accuracyThreshold; /*!< Accuracy threshold used by the training algorithm */
    double tau;               /*!< Tau parameter of the working set selection scheme */
    size_t maxIterations;     /*!< Maximal number of iterations of the algorithm */
    size_t cacheSize;         /*!< Size of the cache, in bytes, for values of the kernel matrix.
                                   A non-zero value enables the cache optimization technique */
    bool doShrinking;         /*!< Flag that enables the shrinking optimization technique */
    size_t shrinkingStep;     /*!< Number of iterations between two steps of the shrinking optimization technique */
    algorithms::kernel_function::KernelIfacePtr kernel; /*!< Kernel function */

    /**
     * Checks the parameter values (presumably a range/consistency validation — the body is defined out of line)
     * \return Status of the check
     */
    services::Status check() const override;
};
/* [Parameter source code] */
} // namespace interface2
namespace interface1
{
/**
* <a name="DAAL-CLASS-ALGORITHMS__SVM__MODEL"></a>
* \brief %Model of the classifier trained by the svm::training::Batch algorithm
*
* \par References
* - Parameter class
* - \ref training::interface2::Batch "training::Batch" class
* - \ref prediction::interface2::Batch "prediction::Batch" class
*/
class DAAL_EXPORT Model : public classifier::Model
{
public:
    DECLARE_MODEL(Model, classifier::Model);

    /**
     * Constructs the SVM model with empty tables of support vectors, classification
     * coefficients and support-vector indices
     * \tparam modelFPType Data type to store SVM model data, double or float
     * \param[in] dummy    Dummy variable for the templated constructor; only its type is used
     * \param[in] nColumns Number of features in input data
     * \param[in] layout   Data layout of the numeric table of support vectors:
     *                     csrArray creates a CSR table, any other value creates a homogeneous table
     * \DAAL_DEPRECATED_USE{ Model::create }
     */
    template <typename modelFPType>
    Model(modelFPType dummy, size_t nColumns, data_management::NumericTableIface::StorageLayout layout = data_management::NumericTableIface::aos)
        : _bias(0.0)
    {
        using namespace data_management;
        // The argument exists only to deduce modelFPType. Discard it explicitly so that
        // translation units including this header do not get -Wunused-parameter warnings.
        (void)dummy;
        if (layout == NumericTableIface::csrArray)
        {
            modelFPType * dummyPtr = NULL;
            _SV                    = CSRNumericTable::create(dummyPtr, NULL, NULL, nColumns);
        }
        else
        {
            _SV = HomogenNumericTable<modelFPType>::create(NULL, nColumns, 0);
        }
        // Coefficients and indices start as empty one-column tables
        _SVCoeff   = HomogenNumericTable<modelFPType>::create(NULL, 1, 0);
        _SVIndices = HomogenNumericTable<int>::create(NULL, 1, 0);
    }

    /**
     * Constructs the SVM model
     * \tparam modelFPType Data type to store SVM model data, double or float
     * \param[in] nColumns Number of features in input data
     * \param[in] layout   Data layout of the numeric table of support vectors
     * \param[out] stat    Status of the model construction
     * \return SVM model
     */
    template <typename modelFPType>
    DAAL_EXPORT static services::SharedPtr<Model> create(
        size_t nColumns, data_management::NumericTableIface::StorageLayout layout = data_management::NumericTableIface::aos,
        services::Status * stat = NULL);

    /**
     * Empty constructor for deserialization; all tables are null and the bias is 0
     * \DAAL_DEPRECATED_USE{ Model::create }
     */
    Model() : _SV(), _SVCoeff(), _bias(0.0), _SVIndices() {}

    /**
     * Constructs an empty SVM model for deserialization
     * \param[out] stat Status of the model construction; if non-null and the returned pointer
     *                  is null, ErrorMemoryAllocationFailed is added to it
     * \return Empty SVM model for deserialization
     */
    static services::SharedPtr<Model> create(services::Status * stat = NULL)
    {
        services::SharedPtr<Model> modelPtr(new Model());
        if (!modelPtr)
        {
            if (stat) stat->add(services::ErrorMemoryAllocationFailed);
        }
        return modelPtr;
    }

    virtual ~Model() {}

    /**
     * Returns the support vectors constructed during the training of the SVM model
     * \return Numeric table of support vectors
     */
    data_management::NumericTablePtr getSupportVectors() { return _SV; }

    /**
     * Returns the indices of the support vectors in the training data set
     * \return Numeric table of support vector indices
     */
    data_management::NumericTablePtr getSupportIndices() { return _SVIndices; }

    /**
     * Returns the classification coefficients constructed during the training of the SVM model
     * \return Numeric table of classification coefficients
     */
    data_management::NumericTablePtr getClassificationCoefficients() { return _SVCoeff; }

    /**
     * Returns the bias constructed during the training of the SVM model
     * \return Bias
     */
    virtual double getBias() { return _bias; }

    /**
     * Sets the bias for the SVM model
     * \param bias Bias of the model
     */
    virtual void setBias(double bias) { _bias = bias; }

    /**
     * Retrieves the number of features in the dataset that was used at the training stage
     * \return Number of columns in the support vectors table, or 0 if that table is not set
     */
    size_t getNumberOfFeatures() const override { return (_SV ? _SV->getNumberOfColumns() : 0); }

protected:
    data_management::NumericTablePtr _SV;        /*!< \private Support vectors */
    data_management::NumericTablePtr _SVCoeff;   /*!< \private Classification coefficients */
    double _bias;                                /*!< \private Bias of the distance function D(x) = w*Phi(x) + bias */
    data_management::NumericTablePtr _SVIndices; /*!< \private Indices of the support vectors in training data set */

    /**
     * Status-reporting counterpart of the templated constructor; defined out of line
     * (presumably the implementation behind Model::create<modelFPType> — confirm)
     */
    template <typename modelFPType>
    DAAL_EXPORT Model(modelFPType dummy, size_t nColumns, data_management::NumericTableIface::StorageLayout layout, services::Status & st);

    /**
     * Serializes or deserializes the model. The base classifier::Model state is processed first;
     * the SVM fields follow only if that succeeds. The same field sequence is used for both
     * directions, so it must not be reordered.
     */
    template <typename Archive, bool onDeserialize>
    services::Status serialImpl(Archive * arch)
    {
        services::Status st = classifier::Model::serialImpl<Archive, onDeserialize>(arch);
        if (!st) return st;
        arch->setSharedPtrObj(_SV);
        arch->setSharedPtrObj(_SVCoeff);
        arch->set(_bias);
        arch->setSharedPtrObj(_SVIndices);
        return st;
    }
};
/** Shared pointer to the SVM model */
using ModelPtr = services::SharedPtr<Model>;
/** @} */
} // namespace interface1
// Bring the versioned SVM entities into the svm namespace
using interface1::Model;
using interface1::ModelPtr;
using interface2::Parameter;
} // namespace svm
/** @} */
} // namespace algorithms
} // namespace daal
#endif