/** file: hikmeans.c ** description: MEX hierarchical ikmeans. ** author: Brian Fulkerson **/ /* Copyright (C) 2007-12 Andrea Vedaldi and Brian Fulkerson. All rights reserved. This file is part of the VLFeat library and is made available under the terms of the BSD license (see the COPYING file). */ #include #include #include #include #include #include #include #include #define NFIELDS(field_names) (sizeof(field_names)/sizeof(*field_names)) enum { opt_max_niters, opt_method, opt_verbose } ; vlmxOption options [] = { {"MaxIters", 1, opt_max_niters }, {"Method", 1, opt_method }, {"Verbose", 0, opt_verbose }, {0, 0, 0 } } ; /** ------------------------------------------------------------------ ** @internal ** @brief Copy HIKM tree node to a MATLAB structure **/ static void xcreate (mxArray *mnode, int i, VlHIKMNode *node) { int node_K = vl_ikm_get_K (node->filter) ; int M = vl_ikm_get_ndims (node->filter) ; vl_ikm_acc const *centers = vl_ikm_get_centers (node->filter) ; mxArray *mcenters ; mcenters = mxCreateNumericMatrix (M, node_K, mxINT32_CLASS, mxREAL); memcpy (mxGetPr(mcenters), centers, sizeof(vl_ikm_acc) * M * node_K) ; mxSetField (mnode, i, "centers", mcenters) ; if (node->children) { mxArray * msub ; const char * field_names[] = {"centers", "sub" } ; mwSize dims [2] ; int k ; dims[0] = 1 ; dims[1] = node_K ; msub = mxCreateStructArray (2, dims, 2, field_names) ; for (k = 0 ; k < node_K ; ++k) { xcreate (msub, k, node -> children [k]) ; } mxSetField (mnode, i, "sub", msub) ; } } /** ------------------------------------------------------------------ ** @internal ** @brief Copy HIKM tree to a MATLAB structure **/ mxArray * hikm_to_matlab (VlHIKMTree * tree) { int K = vl_hikm_get_K (tree) ; int depth = vl_hikm_get_depth (tree) ; mwSize dims [2] = {1, 1} ; mxArray *mtree ; const char *field_names[] = {"K", "depth", "centers", "sub"} ; /* Create the main struct array */ mtree = mxCreateStructArray (2, dims, NFIELDS(field_names), field_names) ; mxSetField (mtree, 0, "K", mxCreateDoubleScalar (K)) ; mxSetField (mtree, 0, "depth", mxCreateDoubleScalar (depth)) ; if (tree->root) xcreate (mtree, 0, tree->root) ; return mtree; } /** ------------------------------------------------------------------ ** @internal ** @brief MEX driver **/ void mexFunction (int nout, mxArray * out[], int nin, const mxArray * in[]) { vl_uint8 *data ; enum {IN_DATA = 0, IN_K, IN_NLEAVES, IN_END} ; enum {OUT_TREE = 0, OUT_ASGN} ; int M, N, K = 2, depth = 0 ; int opt ; int next = IN_END ; mxArray const *optarg ; int nleaves = 1 ; int method_type = VL_IKM_LLOYD ; int max_niters = 200 ; int verb = 0 ; VlHIKMTree* tree ; VL_USE_MATLAB_ENV ; /* ------------------------------------------------------------------ * Check the arguments * --------------------------------------------------------------- */ if (nin < 3) { mexErrMsgTxt ("At least three arguments required."); } else if (nout > 2) { mexErrMsgTxt ("Too many output arguments."); } if (mxGetClassID (in[IN_DATA]) != mxUINT8_CLASS) { mexErrMsgTxt ("DATA must be of class UINT8."); } if (! vlmxIsPlainScalar (in[IN_NLEAVES]) || (nleaves = (int) *mxGetPr (in[IN_NLEAVES])) < 1) { mexErrMsgTxt ("NLEAVES must be a scalar not smaller than 2.") ; } M = mxGetM (in[IN_DATA]); /* n of components */ N = mxGetN (in[IN_DATA]); /* n of elements */ if (! vlmxIsPlainScalar (in[IN_K]) || (K = (int) *mxGetPr (in[IN_K])) > N ) { mexErrMsgTxt ("Cannot have more clusters than data.") ; } data = (vl_uint8 *) mxGetPr (in[IN_DATA]) ; while ((opt = vlmxNextOption (in, nin, options, &next, &optarg)) >= 0) { char buf [1024] ; switch (opt) { case opt_verbose : ++ verb ; break ; case opt_max_niters : if (!vlmxIsPlainScalar(optarg) || (max_niters = (int) *mxGetPr(optarg)) < 1) { mexErrMsgTxt("MaxNiters must be not smaller than 1.") ; } break ; case opt_method : if (!vlmxIsString (optarg, -1)) { mexErrMsgTxt("'Method' must be a string.") ; } if (mxGetString (optarg, buf, sizeof(buf))) { mexErrMsgTxt("Option argument too long.") ; } if (strcmp("lloyd", buf) == 0) { method_type = VL_IKM_LLOYD ; } else if (strcmp("elkan", buf) == 0) { method_type = VL_IKM_ELKAN ; } else { mexErrMsgTxt("Unknown cost type.") ; } break ; default : abort() ; break ; } } /* --------------------------------------------------------------- * Do the job * ------------------------------------------------------------ */ depth = VL_MAX(1, ceil (log (nleaves) / log(K))) ; tree = vl_hikm_new (method_type) ; if (verb) { mexPrintf("hikmeans: # dims: %d\n", M) ; mexPrintf("hikmeans: # data: %d\n", N) ; mexPrintf("hikmeans: K: %d\n", K) ; mexPrintf("hikmeans: depth: %d\n", depth) ; } vl_hikm_set_verbosity (tree, verb) ; vl_hikm_init (tree, M, K, depth) ; vl_hikm_train (tree, data, N) ; out[OUT_TREE] = hikm_to_matlab (tree) ; if (nout > 1) { vl_uint *asgn ; int j ; out [OUT_ASGN] = mxCreateNumericMatrix (vl_hikm_get_depth (tree), N, mxUINT32_CLASS, mxREAL) ; asgn = mxGetData(out[OUT_ASGN]) ; vl_hikm_push (tree, asgn, data, N) ; for (j = 0 ; j < N*depth ; ++ j) asgn [j] ++ ; } if (verb) { mexPrintf("hikmeans: done.\n") ; } /* vl_hikm_delete (tree) ; */ }