/** @file vl_svmtrain.c ** @brief vl_svmtrain MEX definition ** @author Daniele Perrone ** @author Milan Sulc ** @author Andrea Vedaldi **/ /* Copyright (C) 2012 Daniele Perrone. Copyright (C) 2013 Milan Sulc Copyright (C) 2011-13 Andrea Vedaldi. All rights reserved. This file is part of the VLFeat library and is made available under the terms of the BSD license (see the COPYING file). */ #include #include #include #include #include #include #include /* option codes */ enum { // common opt_epsilon, opt_max_num_iterations, opt_bias_multiplier, opt_diagnostic_function, opt_diagnostic_frequency, opt_validation_subset, opt_loss, opt_model, opt_bias, opt_weights, // switching to SDCA opt_verbose, opt_solver, // SGD specific opt_starting_iteration, opt_bias_learning_rate // DCA specific: }; /* options */ vlmxOption options [] = { {"Epsilon", 1, opt_epsilon }, {"MaxNumIterations", 1, opt_max_num_iterations }, {"BiasMultiplier", 1, opt_bias_multiplier }, {"DiagnosticFunction", 1, opt_diagnostic_function }, {"DiagnosticFrequency", 1, opt_diagnostic_frequency}, {"ValidationSubset", 1, opt_validation_subset }, {"Loss", 1, opt_loss }, {"Verbose", 0, opt_verbose }, {"Solver", 1, opt_solver }, {"Model", 1, opt_model }, {"Bias", 1, opt_bias }, {"Weights", 1, opt_weights }, // SGD specific {"StartingIteration", 1, opt_starting_iteration }, {"BiasLearningRate", 1, opt_bias_learning_rate }, // DCA specific: {0, 0, 0 } } ; mxArray * createScalarStructArray(void const **fields) { void const **iter ; char const **niter ; char const **names ; vl_size numFields = 0 ; mxArray * s ; mwSize dims [] = {1, 1} ; for (iter = fields ; *iter ; iter += 2) numFields++ ; names = vl_calloc(numFields, sizeof(char const*)) ; for (iter = fields, niter = names ; *iter ; iter += 2, niter++) { *niter = *iter ; } s = mxCreateStructArray(sizeof(dims)/sizeof(dims[0]), dims, (int)numFields, names) ; for (iter = fields, niter = names ; *iter; iter += 2, niter++) { mxSetField(s, 0, *niter, (mxArray*)(*(iter+1))) ; } return s ; } /* ---------------------------------------------------------------- */ /* Parsing datasets */ /* ---------------------------------------------------------------- */ VlSvmDataset * parseDataset(const mxArray * dataset_array) { VlSvmDataset * dataset ; { mxArray * data_array ; mxClassID dataClass ; vl_size dimension ; vl_size numData ; vl_type dataType ; if (! mxIsStruct(dataset_array)) { vlmxError(vlmxErrInvalidArgument, "DATASET is not a structure.") ; } if (mxGetNumberOfElements(dataset_array) != 1) { vlmxError(vlmxErrInvalidArgument, "DATASET is not a singleton.") ; } data_array = mxGetField(dataset_array, 0, "data") ; if (data_array == NULL) { vlmxError(vlmxErrInvalidArgument, "DATASET is missing the DATA field.") ; } if (!vlmxIsMatrix(data_array,-1,-1)) { vlmxError(vlmxErrInvalidArgument,"DATASET.DATA is not a matrix.") ; } dimension = mxGetM (data_array) ; numData = mxGetN (data_array) ; dataClass = mxGetClassID (data_array) ; if (dimension == 0 || numData == 0) { vlmxError(vlmxErrInvalidArgument, "DATASET.DATA is empty.") ; } switch (dataClass) { case mxSINGLE_CLASS : dataType = VL_TYPE_FLOAT ; break ; case mxDOUBLE_CLASS : dataType = VL_TYPE_DOUBLE ; break ; default: vlmxError(vlmxErrInvalidArgument, "DATASET.DATA is neither either SINGLE or DOUBLE.") ; } dataset = vl_svmdataset_new(dataType, mxGetData(data_array), dimension, numData) ; } /* homogeneous kernel map support */ { VlHomogeneousKernelType kernelType = VlHomogeneousKernelChi2 ; VlHomogeneousKernelMapWindowType windowType = VlHomogeneousKernelMapWindowRectangular ; double gamma = 1.0 ; double period = -1 ; int n = 1 ; VlHomogeneousKernelMap * hom = NULL ; mxArray * hom_array ; mxArray * field ; hom_array = mxGetField(dataset_array, 0, "homkermap") ; if (hom_array != NULL) { if (!mxIsStruct(hom_array)) { vlmxError(vlmxErrInvalidArgument, "DATASET.HOMKERMAP is not a structure") ; } field = mxGetField(hom_array, 0, "order") ; if (field != NULL) { if (! vlmxIsPlainScalar(field)) { vlmxError(vlmxErrInvalidArgument, "DATASET.HOMKERMAP.ORDER is not a scalar.") ; } n = *mxGetPr(field) ; if (n < 0) { vlmxError(vlmxErrInvalidArgument, "DATASET.HOMKERMAP.ORDER is negative.") ; } } field = mxGetField(hom_array, 0, "kernel") ; if (field != NULL) { char buffer [1024] ; mxGetString(field, buffer, sizeof(buffer) / sizeof(char)) ; if (vl_string_casei_cmp("kl1", buffer) == 0) { kernelType = VlHomogeneousKernelIntersection ; } else if (vl_string_casei_cmp("kchi2", buffer) == 0) { kernelType = VlHomogeneousKernelChi2 ; } else if (vl_string_casei_cmp("kjs", buffer) == 0) { kernelType = VlHomogeneousKernelJS ; } else if (vl_string_casei_cmp("kinters", buffer) == 0) { kernelType = VlHomogeneousKernelIntersection ; } else { vlmxError(vlmxErrInvalidArgument, "DATASET.HOMKERMAP.KERNEL is not a recognized kernel type.") ; } } field = mxGetField(hom_array, 0, "window") ; if (field != NULL) { char buffer [1024] ; mxGetString(field, buffer, sizeof(buffer) / sizeof(char)) ; if (vl_string_casei_cmp("uniform", buffer) == 0) { windowType = VlHomogeneousKernelMapWindowUniform ; } else if (vl_string_casei_cmp("rectangular", buffer) == 0) { windowType = VlHomogeneousKernelMapWindowRectangular; } else { vlmxError(vlmxErrInvalidArgument, "DATASET.HOMKERMAP.WINDOW is not a recognized window type.") ; } } field = mxGetField(hom_array, 0, "gamma") ; if (field != NULL) { if (! vlmxIsPlainScalar(field)) { vlmxError(vlmxErrInvalidArgument, "GAMMA is not a scalar.") ; } gamma = *mxGetPr(field) ; if (gamma <= 0) { vlmxError(vlmxErrInvalidArgument, "GAMMA is not positive.") ; } } field = mxGetField(hom_array, 0, "period") ; if (field != NULL) { if (! vlmxIsPlainScalar(field)) { vlmxError(vlmxErrInvalidArgument, "PERIOD is not a scalar.") ; } period = *mxGetPr(field) ; if (period <= 0) { vlmxError(vlmxErrInvalidArgument, "PERIOD is not positive.") ; } } hom = vl_homogeneouskernelmap_new (kernelType, gamma, n, period, windowType) ; vl_svmdataset_set_homogeneous_kernel_map (dataset, hom) ; } } return dataset ; } /* ---------------------------------------------------------------- */ /* Diagnostic helpers */ /* ---------------------------------------------------------------- */ mxArray * makeInfoStruct (VlSvm* svm) { VlSvmStatistics const * s = vl_svm_get_statistics(svm) ; mxArray * info = 0 ; switch (vl_svm_get_solver(svm)) { case VlSvmSolverSdca: { void const * fields [] = { "solver", mxCreateString("sdca"), "lambda", vlmxCreatePlainScalar(vl_svm_get_lambda(svm)), "biasMultiplier", vlmxCreatePlainScalar(vl_svm_get_bias_multiplier(svm)), "bias", vlmxCreatePlainScalar(vl_svm_get_bias(svm)), "objective", vlmxCreatePlainScalar(s->objective), "regularizer", vlmxCreatePlainScalar(s->regularizer), "loss", vlmxCreatePlainScalar(s->loss), "dualObjective", vlmxCreatePlainScalar(s->dualObjective), "dualLoss", vlmxCreatePlainScalar(s->dualLoss), "dualityGap", vlmxCreatePlainScalar(s->dualityGap), "iteration", vlmxCreatePlainScalar(s->iteration), "epoch", vlmxCreatePlainScalar(s->epoch), "elapsedTime", vlmxCreatePlainScalar(s->elapsedTime), 0, 0 } ; info = createScalarStructArray(fields) ; break ; } case VlSvmSolverSgd: { void const * fields [] = { "solver", mxCreateString("sgd"), "lambda", vlmxCreatePlainScalar(vl_svm_get_lambda(svm)), "biasMultiplier", vlmxCreatePlainScalar(vl_svm_get_bias_multiplier(svm)), "bias", vlmxCreatePlainScalar(vl_svm_get_bias(svm)), "objective", vlmxCreatePlainScalar(s->objective), "regularizer", vlmxCreatePlainScalar(s->regularizer), "loss", vlmxCreatePlainScalar(s->loss), "scoreVariation", vlmxCreatePlainScalar(s->scoresVariation), "iteration", vlmxCreatePlainScalar(s->iteration), "epoch", vlmxCreatePlainScalar(s->epoch), "elapsedTime", vlmxCreatePlainScalar(s->elapsedTime), 0, 0 } ; info = createScalarStructArray(fields) ; break ; } case VlSvmSolverNone : { void const * fields [] = { "solver", mxCreateString("none"), "lambda", vlmxCreatePlainScalar(vl_svm_get_lambda(svm)), "biasMultiplier", vlmxCreatePlainScalar(vl_svm_get_bias_multiplier(svm)), "bias", vlmxCreatePlainScalar(vl_svm_get_bias(svm)), "objective", vlmxCreatePlainScalar(s->objective), "regularizer", vlmxCreatePlainScalar(s->regularizer), "loss", vlmxCreatePlainScalar(s->loss), "elapsedTime", vlmxCreatePlainScalar(s->elapsedTime), 0, 0 } ; info = createScalarStructArray(fields) ; break ; } default: assert(0) ; } return info ; } /* ---------------------------------------------------------------- */ /* SVM diagnostic callback */ /* ---------------------------------------------------------------- */ typedef struct DiagnsoticOpts_ { vl_bool verbose ; mxArray const * matlabDiagonsticFunctionHandle ; } DiagnosticOpts ; void diagnostic (VlSvm * svm, DiagnosticOpts * opts) { VlSvmStatistics const * s = vl_svm_get_statistics(svm) ; if ((opts->verbose && s->status != VlSvmStatusTraining) || (opts->verbose > 1)) { const char * statusName = 0 ; switch (s->status) { case VlSvmStatusTraining: statusName = "training" ; break ; case VlSvmStatusConverged: statusName = "converged" ; break ; case VlSvmStatusMaxNumIterationsReached: statusName = "max num iterations reached" ; break ; } mexPrintf("vl_svmtrain: iteration: %d (epoch: %d)\n", s->iteration+1, s->epoch+1) ; mexPrintf("\ttime elapsed: %f\n", s->elapsedTime) ; mexPrintf("\tobjective: %g (regul: %g, loss: %g)\n", s->objective, s->regularizer, s->loss) ; switch (vl_svm_get_solver(svm)) { case VlSvmSolverSgd: mexPrintf("\tscore variation: %f\n", s->scoresVariation) ; break; case VlSvmSolverSdca: mexPrintf("\tdual objective: %g (dual loss: %g)\n", s->dualObjective, s->dualLoss) ; mexPrintf("\tduality gap: %g\n", s->dualityGap) ; break; default: break; } mexPrintf("\tstatus: %s\n", statusName) ; } if (opts->matlabDiagonsticFunctionHandle) { mxArray *rhs[2] ; rhs[0] = (mxArray*) opts->matlabDiagonsticFunctionHandle ; rhs[1] = makeInfoStruct(svm) ; if (mxIsClass(rhs[0] , "function_handle")) { mexCallMATLAB(0,NULL,sizeof(rhs)/sizeof(rhs[0]),rhs,"feval") ; } mxDestroyArray(rhs[1]) ; } } /* ---------------------------------------------------------------- */ /* MEX entry point */ /* ---------------------------------------------------------------- */ void mexFunction(int nout, mxArray *out[], int nin, const mxArray *in[]) { enum {IN_DATASET = 0, IN_LABELS, IN_LAMBDA, IN_END} ; enum {OUT_MODEL = 0, OUT_BIAS, OUT_INFO, OUT_SCORES, OUT_END} ; vl_int opt, next; mxArray const *optarg ; VlSvmSolverType solver = VlSvmSolverSdca ; VlSvmLossType loss = VlSvmLossHinge ; int verbose = 0 ; VlSvmDataset * dataset ; double * labels ; double * weights = NULL ; double lambda ; double epsilon = -1 ; double biasMultipler = -1 ; vl_index maxNumIterations = -1 ; vl_index diagnosticFrequency = -1 ; mxArray const * matlabDiagnosticFunctionHandle = NULL ; mxArray const * initialModel_array = NULL ; double initialBias = VL_NAN_D ; vl_index startingIteration = -1 ; /* SGD */ double sgdBiasLearningRate = -1 ; VL_USE_MATLAB_ENV ; if (nin < 3) { vlmxError(vlmxErrInvalidArgument, "At least three arguments are required.") ; } if (nout > OUT_END) { vlmxError(vlmxErrInvalidArgument, "Too many output arguments."); } #define GET_SCALAR(NAME, variable) \ if (!vlmxIsPlainScalar(optarg)) { \ vlmxError(vlmxErrInvalidArgument, VL_STRINGIFY(NAME) " is not a plain scalar.") ; \ } \ variable = (double) *mxGetPr(optarg); #define GET_NN_SCALAR(NAME, variable) GET_SCALAR(NAME, variable) \ if (variable < 0) { \ vlmxError(vlmxErrInvalidArgument, VL_STRINGIFY(NAME) " is negative.") ; \ } /* Mode 1: pass data, labels, lambda, and options */ if (mxIsNumeric(in[IN_DATASET])) { mxArray const* samples_array = in[IN_DATASET] ; vl_size dimension ; vl_size numSamples ; void * data ; vl_type dataType ; if (!vlmxIsMatrix(samples_array, -1, -1)) { vlmxError (vlmxErrInvalidArgument, "X is not a matrix.") ; } if (mxGetClassID(samples_array) == mxDOUBLE_CLASS) { dataType = VL_TYPE_DOUBLE ; } else if (mxGetClassID(samples_array) == mxSINGLE_CLASS) { dataType = VL_TYPE_FLOAT ; } else { vlmxError (vlmxErrInvalidArgument, "X is not of class SINGLE or DOUBLE.") ; } data = mxGetData(samples_array) ; dimension = mxGetM(samples_array) ; numSamples = mxGetN(samples_array) ; dataset = vl_svmdataset_new(dataType, data, dimension, numSamples) ; } /* Mode 2: pass dataset structure */ else { dataset = parseDataset(in[IN_DATASET]) ; } { mxArray const* labels_array = in[IN_LABELS] ; if (!vlmxIsPlainMatrix(labels_array, -1, -1)) { vlmxError (vlmxErrInvalidArgument, "Y is not a plain matrix.") ; } labels = mxGetPr(labels_array) ; if (mxGetNumberOfElements(labels_array) != vl_svmdataset_get_num_data(dataset)) { vlmxError (vlmxErrInvalidArgument, "The number of labels Y is not the same as the number of data samples X.") ; } optarg = in[IN_LAMBDA] ; GET_NN_SCALAR(LAMBDA, lambda) ; } /* Parse optional arguments */ next = 3 ; while ((opt = vlmxNextOption (in, nin, options, &next, &optarg)) >= 0) { char buf [1024] ; switch (opt) { case opt_verbose: verbose ++ ; break ; case opt_epsilon: GET_NN_SCALAR(EPSLON, epsilon) ; break ; case opt_bias_multiplier: GET_NN_SCALAR(BIASMULTIPLIER, biasMultipler) ; break ; case opt_max_num_iterations: GET_NN_SCALAR(MAXNUMITERATIONS, maxNumIterations) ; break ; case opt_diagnostic_frequency: GET_NN_SCALAR(DIAGNOSTICFREQUENCY, diagnosticFrequency) ; break ; case opt_diagnostic_function: if (!mxIsClass(optarg ,"function_handle")) { mexErrMsgTxt("DIAGNOSTICSFUNCTION is not a function handle."); } matlabDiagnosticFunctionHandle = optarg ; break ; case opt_solver : if (!vlmxIsString (optarg, -1)) { vlmxError (vlmxErrInvalidArgument, "SOLVER must be a string.") ; } if (mxGetString (optarg, buf, sizeof(buf))) { vlmxError (vlmxErrInvalidArgument, "SOLVER argument too long.") ; } if (vlmxCompareStringsI("sgd", buf) == 0) { solver = VlSvmSolverSgd ; } else if (vlmxCompareStringsI("sdca", buf) == 0) { solver = VlSvmSolverSdca ; } else if (vlmxCompareStringsI("none", buf) == 0) { solver = VlSvmSolverNone; } else { vlmxError (vlmxErrInvalidArgument, "Invalid value %s for SOLVER", buf) ; } break ; case opt_loss : if (!vlmxIsString (optarg, -1)) { vlmxError (vlmxErrInvalidArgument, "LOSS must be a string.") ; } if (mxGetString (optarg, buf, sizeof(buf))) { vlmxError (vlmxErrInvalidArgument, "LOSS argument too long.") ; } if (vlmxCompareStringsI("hinge", buf) == 0) { loss = VlSvmLossHinge ; } else if (vlmxCompareStringsI("hinge2", buf) == 0) { loss = VlSvmLossHinge2 ; } else if (vlmxCompareStringsI("l1", buf) == 0) { loss = VlSvmLossL1 ; } else if (vlmxCompareStringsI("l2", buf) == 0) { loss = VlSvmLossL2 ; } else if (vlmxCompareStringsI("logistic", buf) == 0) { loss = VlSvmLossLogistic ; } else { vlmxError (vlmxErrInvalidArgument, "Invalid value %s for LOSS", buf) ; } break ; case opt_model : if (!vlmxIsPlainVector(optarg, vl_svmdataset_get_dimension(dataset))) { vlmxError(vlmxErrInvalidArgument, "MODEL is not a plain vector of size equal to the data dimension.") ; } initialModel_array = optarg ; break ; case opt_bias: GET_SCALAR(BIAS, initialBias) ; break ; case opt_weights: if (!vlmxIsPlainVector(optarg, vl_svmdataset_get_num_data(dataset))) { vlmxError(vlmxErrInvalidArgument, "WEIGHTS is not a plain vector of size equal to the number of training samples.") ; } weights = mxGetPr(optarg) ; break ; /* SGD specific */ case opt_starting_iteration: GET_NN_SCALAR(STARTINGITERATION, startingIteration) ; break ; case opt_bias_learning_rate: GET_NN_SCALAR(BIASLEARNINGRATE, sgdBiasLearningRate) ; break ; /* DCA specific */ } /* choose option */ } /* next option */ { VlSvm * svm = vl_svm_new_with_dataset(solver, dataset, labels, lambda) ; DiagnosticOpts dopts ; if (initialModel_array) { if (solver != VlSvmSolverNone && solver != VlSvmSolverSgd) { vlmxError(vlmxErrInvalidArgument, "MODEL cannot be specified with this type of solver.") ; } if (mxGetNumberOfElements(initialModel_array) != vl_svm_get_dimension(svm)) { vlmxError(vlmxErrInvalidArgument, "MODEL has not the same dimension as the data.") ; } vl_svm_set_model(svm, mxGetPr(initialModel_array)) ; } if (! vl_is_nan_d(initialBias)) { if (solver != VlSvmSolverNone && solver != VlSvmSolverSgd) { vlmxError(vlmxErrInvalidArgument, "BIAS cannot be specified with this type of solver.") ; } vl_svm_set_bias(svm, initialBias) ; } if (epsilon >= 0) vl_svm_set_epsilon(svm, epsilon) ; if (maxNumIterations >= 0) vl_svm_set_max_num_iterations(svm, maxNumIterations) ; if (biasMultipler >= 0) vl_svm_set_bias_multiplier(svm, biasMultipler) ; if (sgdBiasLearningRate >= 0) vl_svm_set_bias_learning_rate(svm, sgdBiasLearningRate) ; if (diagnosticFrequency >= 0) vl_svm_set_diagnostic_frequency(svm, diagnosticFrequency) ; if (startingIteration >= 0) vl_svm_set_iteration_number(svm, (unsigned)startingIteration) ; if (weights) vl_svm_set_weights(svm, weights) ; vl_svm_set_loss (svm, loss) ; dopts.verbose = verbose ; dopts.matlabDiagonsticFunctionHandle = matlabDiagnosticFunctionHandle ; vl_svm_set_diagnostic_function (svm, (VlSvmDiagnosticFunction)diagnostic, &dopts) ; if (verbose) { double C = 1.0 / (vl_svm_get_lambda(svm) * vl_svm_get_num_data (svm)) ; char const * lossName = 0 ; switch (loss) { case VlSvmLossHinge: lossName = "hinge" ; break ; case VlSvmLossHinge2: lossName = "hinge2" ; break ; case VlSvmLossL1: lossName = "l1" ; break ; case VlSvmLossL2: lossName = "l2" ; break ; case VlSvmLossLogistic: lossName = "logistic" ; break ; } mexPrintf("vl_svmtrain: parameters (verbosity: %d)\n", verbose) ; mexPrintf("\tdata dimension: %d\n",vl_svmdataset_get_dimension(dataset)) ; mexPrintf("\tnum samples: %d\n", vl_svmdataset_get_num_data(dataset)) ; mexPrintf("\tlambda: %g (C equivalent: %g)\n", vl_svm_get_lambda(svm), C) ; mexPrintf("\tloss function: %s\n", lossName) ; mexPrintf("\tmax num iterations: %d\n", vl_svm_get_max_num_iterations(svm)) ; mexPrintf("\tepsilon: %g\n", vl_svm_get_epsilon(svm)) ; mexPrintf("\tdiagnostic frequency: %d\n", vl_svm_get_diagnostic_frequency(svm)) ; mexPrintf("\tusing custom weights: %s\n", VL_YESNO(weights)) ; mexPrintf("\tbias multiplier: %g\n", vl_svm_get_bias_multiplier(svm)) ; switch (vl_svm_get_solver(svm)) { case VlSvmSolverNone: mexPrintf("\tsolver: none (evaluation mode)\n") ; break ; case VlSvmSolverSgd: mexPrintf("\tsolver: sgd\n") ; mexPrintf("\tbias learning rate: %g\n", vl_svm_get_bias_learning_rate(svm)) ; break ; case VlSvmSolverSdca: mexPrintf("\tsolver: sdca\n") ; break ; } } vl_svm_train(svm) ; { mwSize dims[2] ; dims[0] = vl_svmdataset_get_dimension(dataset) ; dims[1] = 1 ; out[OUT_MODEL] = mxCreateNumericArray(2, dims, mxDOUBLE_CLASS, mxREAL) ; memcpy(mxGetPr(out[OUT_MODEL]), vl_svm_get_model(svm), vl_svm_get_dimension(svm) * sizeof(double)) ; } out[OUT_BIAS] = vlmxCreatePlainScalar(vl_svm_get_bias(svm)) ; if (nout >= 3) { out[OUT_INFO] = makeInfoStruct(svm) ; } if (nout >= 4) { mwSize dims[2] ; dims[0] = 1 ; dims[1] = vl_svmdataset_get_num_data(dataset) ; out[OUT_SCORES] = mxCreateNumericArray(2, dims, mxDOUBLE_CLASS, mxREAL) ; memcpy(mxGetPr(out[OUT_SCORES]), vl_svm_get_scores(svm), vl_svm_get_num_data(svm) * sizeof(double)) ; } vl_svm_delete(svm) ; if (vl_svmdataset_get_homogeneous_kernel_map(dataset)) { VlHomogeneousKernelMap * hom = vl_svmdataset_get_homogeneous_kernel_map(dataset) ; vl_svmdataset_set_homogeneous_kernel_map(dataset,0) ; vl_homogeneouskernelmap_delete(hom) ; } vl_svmdataset_delete(dataset) ; } }