[cig-commits] r9124 - cs/benchmark/cigma/trunk/src/tests

luis at geodynamics.org luis at geodynamics.org
Wed Jan 23 18:10:29 PST 2008


Author: luis
Date: 2008-01-23 18:10:29 -0800 (Wed, 23 Jan 2008)
New Revision: 9124

Added:
   cs/benchmark/cigma/trunk/src/tests/TestAnn.cpp
Log:
Test case for ANN spatial indexing. Searches are at least 160 times faster!

Added: cs/benchmark/cigma/trunk/src/tests/TestAnn.cpp
===================================================================
--- cs/benchmark/cigma/trunk/src/tests/TestAnn.cpp	                        (rev 0)
+++ cs/benchmark/cigma/trunk/src/tests/TestAnn.cpp	2008-01-24 02:10:29 UTC (rev 9124)
@@ -0,0 +1,458 @@
+#include <iostream>
+#include <fstream>
+#include <iomanip>
+#include <string>
+
+#include <cassert>
+#include <cstdlib>
+#include <ctime>
+
+#include "ANN/ANN.h"
+
+#include "MeshPart.h"
+#include "VtkUgMeshPart.h"
+#include "VtkUgReader.h"
+
+#include "Cell.h"
+#include "Tet.h"
+#include "Hex.h"
+#include "Numeric.h"
+#include "Misc.h"
+
+#define NUM_QUERY_POINTS (10*1000)
+
+using namespace std;
+using namespace cigma;
+
+
+
+//
+// SpatialIndex -- small wrapper around an ANN kd-tree used to look up
+// the data points nearest to a query location.
+//
+// The object owns the ANN point array, the kd-tree and the search
+// buffers it allocates in set_data().
+//
+class SpatialIndex
+{
+public:
+    SpatialIndex();
+    ~SpatialIndex();
+
+public:
+    // Copies npts points of dimension dim out of the flat array data
+    // (component j of point i is data[dim*i + j]) and builds the
+    // kd-tree over the copy.
+    void set_data(double *data, int npts, int dim);
+
+public:
+    // Runs a k-nearest-neighbor search for globalPoint and writes the k
+    // resulting data point indices into cellIndex, which must therefore
+    // have room for at least k ints. Returns true when the search was
+    // carried out.
+    bool find(double globalPoint[3], int *cellIndex);
+
+public:
+    int k;                  // number of nearest neighbors
+    int npts;               // number of data points
+    int dim;                // dimension of data point
+
+    double epsilon;         // error bound handed to annkSearch()
+
+    ANNpointArray dataPoints;   // private copy of the indexed points
+    ANNkd_tree *kdtree;         // tree built over dataPoints
+
+    ANNpoint queryPoint;    // query point
+    ANNidxArray nnIdx;      // near neighbor indices
+    ANNdistArray nnDists;   // near neighbor distances
+
+private:
+    // The destructor frees raw allocations, so a member-wise copy would
+    // release them twice. Copying is disabled: both operations are
+    // declared here and intentionally never defined.
+    SpatialIndex(const SpatialIndex &);
+    SpatialIndex &operator=(const SpatialIndex &);
+};
+
+
+
+//
+// Creates an empty index: no points, no tree and no search buffers.
+// Searches default to 8 neighbors with an epsilon of 0.
+//
+SpatialIndex::SpatialIndex()
+    : k(8),
+      npts(0),
+      dim(0),
+      epsilon(0),
+      dataPoints(0),
+      kdtree(0),
+      queryPoint(0),
+      nnIdx(0),
+      nnDists(0)
+{
+}
+
+
+//
+// Releases the kd-tree, the point array, the query point and the
+// neighbor buffers, skipping whatever was never allocated.
+//
+SpatialIndex::~SpatialIndex()
+{
+    // The tree was built over dataPoints, so it goes first.
+    if (kdtree != 0) delete kdtree;
+    if (dataPoints != 0) annDeallocPts(dataPoints);
+
+    // queryPoint comes from annAllocPt() and has to be handed back to
+    // ANN as well; it used to be leaked.
+    if (queryPoint != 0) annDeallocPt(queryPoint);
+
+    if (nnIdx != 0) delete [] nnIdx;
+    if (nnDists != 0) delete [] nnDists;
+}
+
+//
+// Builds the kd-tree over a copy of the given points.
+//
+//   data  flat array holding npts*dim values; component j of point i
+//         is read from data[dim*i + j]. The array is only read, so the
+//         caller may release it as soon as this function returns.
+//   npts  number of points, must be positive
+//   dim   number of components per point, must be positive
+//
+// The neighbor buffers are sized from the current value of k, so k has
+// to be chosen before calling this function. Calling it again replaces
+// the previous contents of the index.
+//
+void SpatialIndex::set_data(double *data, int npts, int dim)
+{
+    assert(data != 0);
+    assert(npts > 0);
+    assert(dim > 0);
+
+    // A search cannot return more neighbors than there are points.
+    assert(k > 0 && k <= npts);
+
+    // Drop whatever an earlier call allocated; without this a second
+    // call would leak the old tree, points and buffers.
+    if (kdtree != 0)
+    {
+        delete kdtree;
+        kdtree = 0;
+    }
+    if (dataPoints != 0) annDeallocPts(dataPoints);
+    if (queryPoint != 0) annDeallocPt(queryPoint);
+    delete [] nnIdx;
+    delete [] nnDists;
+
+    this->npts = npts;
+    this->dim = dim;
+
+    dataPoints = annAllocPts(npts, dim);
+    queryPoint = annAllocPt(dim);
+
+    nnIdx = new ANNidx[k];
+    nnDists = new ANNdist[k];
+
+    // Copy the caller's coordinates into the ANN point array.
+    for (int i = 0; i < npts; i++)
+    {
+        ANNpoint pt = dataPoints[i];
+        const double *src = &data[dim * i];
+        for (int j = 0; j < dim; j++)
+        {
+            pt[j] = src[j];
+        }
+    }
+
+    kdtree = new ANNkd_tree(dataPoints, npts, dim);
+}
+
+//
+// Looks up the k data points nearest to globalPoint.
+//
+//   globalPoint  the (x,y,z) location to search for
+//   cellIndex    output array with room for at least k ints; receives
+//                the indices of the k neighbors in the order reported
+//                by annkSearch()
+//
+// Returns false, without touching cellIndex, when no index has been
+// built yet or cellIndex is null. Returns true otherwise.
+//
+bool SpatialIndex::find(double globalPoint[3], int *cellIndex)
+{
+    if (kdtree == 0 || queryPoint == 0 || cellIndex == 0)
+    {
+        return false;
+    }
+
+    // The query vector repeats the point, (x,y,z,x,y,z), which matches
+    // six-component data such as (min corner, max corner) boxes. Never
+    // write past the dim components that queryPoint was allocated with;
+    // the unconditional six writes used before overflowed it whenever
+    // dim was smaller than 6.
+    const int ncoords = (dim < 6) ? dim : 6;
+    for (int j = 0; j < ncoords; j++)
+    {
+        queryPoint[j] = globalPoint[j % 3];
+    }
+
+    const bool verbose = false;
+
+    if (verbose)
+    {
+        cout << "Searching for (" << globalPoint[0] << " " << globalPoint[1] << " " << globalPoint[2] << ") ";
+    }
+
+    // search
+    kdtree->annkSearch(queryPoint, k, nnIdx, nnDists, epsilon);
+
+    // hand all k candidate indices back to the caller
+    for (int i = 0; i < k; i++)
+    {
+        cellIndex[i] = nnIdx[i];
+    }
+
+    if (verbose)
+    {
+        cout << "-> ";
+        cout << "index " << nnIdx[0] << " ";
+        cout << "with dist " << nnDists[0] << endl;
+    }
+
+    return true;
+}
+
+//
+// Completes one in-place progress line on stdout:
+//
+//   <cell> <points/sec> <elapsed> <eta> <total> <progress>%
+//
+// followed by blank padding and a carriage return, so that the next
+// report overwrites this one. All times are in minutes.
+//
+//   t0         time at which the current search loop started
+//   i          zero-based index of the query point just processed
+//   num_pts    total number of query points
+//   cellIndex  cell index to display for this point
+//
+// Returns the projected total running time of the loop, in minutes.
+//
+static double report_progress(time_t t0, int i, int num_pts, int cellIndex)
+{
+    time_t now;
+    time(&now);
+
+    const int done = i + 1;
+    const double elapsed_mins = difftime(now, t0) / 60.0;
+    const double rate = elapsed_mins / done;    // minutes per point
+    const double remaining_mins = (num_pts - i) * rate;
+    const double total_mins = num_pts * rate;
+
+    // No measurable time may have elapsed yet (always the case for the
+    // very first report). Dividing by the zero rate used to print "inf"
+    // points/sec and "nan%"; report 0 until a rate can be measured and
+    // derive the percentage from the point count instead.
+    const double points_per_sec = (rate > 0) ? (1.0/60.0) / rate : 0.0;
+    const double progress = (100.0 * done) / num_pts;
+
+    cout << cellIndex << " "
+         << points_per_sec << " "
+         << elapsed_mins << " "
+         << remaining_mins << " "
+         << total_mins << " "
+         << progress << "% ";
+
+    // blank out leftovers from a longer previous line, then rewind
+    cout << "                                 "
+         << "                                 "
+         << "                                 "
+         << "                                 "
+         << "\r";
+
+    cout << std::flush;
+
+    return total_mins;
+}
+
+//
+// Test driver for the ANN based spatial index.
+//
+// Reads an unstructured VTK mesh (first command line argument, or a
+// default file name), generates NUM_QUERY_POINTS random points inside
+// the mesh bounding box, locates each of them once by brute force over
+// all cells and once through a kd-tree built over the per-cell bounding
+// boxes, and writes both answers side by side to "foo.index". Each line
+// of that file ends in "yes" when the brute force cell is among the
+// kd-tree candidates and in "no!" otherwise.
+//
+// Returns 0 on success and a nonzero status when the mesh has an
+// unsupported cell type or the output file cannot be opened.
+//
+int main(int argc, char *argv[])
+{
+
+    cout << "Testing ANN library..." << endl;
+
+    int e,i,j;
+    int cellIndex;
+    const int output_frequency = 100;
+
+    //
+    // Timing info
+    //
+    time_t t0;
+    double total_mins = 0.0;
+
+
+    //
+    // Read mesh from file
+    //
+    string filename = "strikeslip_tet4_1000m_t0.vtk";
+    if (argc > 1)
+    {
+        filename = argv[1];
+    }
+    int nno, nsd;
+    double *coords;
+    int nel, ndofs;
+    int *connect;
+
+    VtkUgReader reader;
+    reader.open(filename);
+    reader.get_coordinates(&coords, &nno, &nsd);
+    reader.get_connectivity(&connect, &nel, &ndofs);
+    assert(nsd == 3);
+
+    // NOTE(review): meshPart, its cell, coords and connect are never
+    // released below. Who owns them is not visible from this file --
+    // confirm against MeshPart/VtkUgReader before adding deletes.
+    MeshPart *meshPart;
+    meshPart = new VtkUgMeshPart();
+    meshPart->set_coordinates(coords, nno, nsd);
+    meshPart->set_connectivity(connect, nel, ndofs);
+    meshPart->cell = 0;
+    switch (ndofs)
+    {
+        case 4:
+            meshPart->cell = new Tet();
+            break;
+        case 8:
+            meshPart->cell = new Hex();
+            break;
+    }
+    if (meshPart->cell == 0)
+    {
+        // An assert alone would vanish under NDEBUG and leave a null
+        // cell to be dereferenced further down.
+        cerr << "Error: unsupported cell type (" << ndofs
+             << " nodes per cell)" << endl;
+        return EXIT_FAILURE;
+    }
+
+    Cell *cell = meshPart->cell;
+
+    //
+    // Get global bounding box
+    //
+    double global_minpt[3];
+    double global_maxpt[3];
+    cigma::minmax(coords, nno, nsd, global_minpt, global_maxpt);
+
+    //
+    // Generate random query points
+    //
+    const int num_query_pts = NUM_QUERY_POINTS;
+    double *query_points = new double[num_query_pts * 3];
+    int *query_cells1 = new int[num_query_pts];
+
+    cout << "Generating " << num_query_pts << " query points" << endl << endl;
+
+    for (i = 0; i < num_query_pts; i++)
+    {
+        query_cells1[i] = 0;
+        bbox_random_point(global_minpt, global_maxpt, &query_points[3*i]);
+    }
+
+    //
+    // Brute force search -- obtain the right answers here, for
+    // double-checking the spatial indexing method
+    //
+
+    //*
+    cout << "Brute force method: " << endl;
+    cout << "point cell points/sec elapsed eta total progress" << endl;
+
+    time(&t0);
+    for (i = 0; i < num_query_pts; i++)
+    {
+        const bool report = (i % output_frequency == 0);
+        bool found = false;
+        double *point = &query_points[3*i];
+
+        if (report)
+        {
+            cout << "(" << point[0] << " "
+                        << point[1] << " "
+                        << point[2] << ") ";
+        }
+
+        // test every cell in turn until one contains the point
+        cellIndex = -1;
+        for (e = 0; e < nel; e++)
+        {
+            double uvw[3];
+
+            // update cell data
+            meshPart->get_cell_coords(e, cell->globverts);
+
+            cell->xyz2uvw(point, uvw);
+
+            found = cell->interior(uvw[0], uvw[1], uvw[2]);
+
+            if (found)
+            {
+                cellIndex = e;
+                break;
+            }
+        }
+        assert(found);
+
+        query_cells1[i] = cellIndex;
+
+        if (report)
+        {
+            total_mins = report_progress(t0, i, num_query_pts, cellIndex);
+        }
+    }
+    cout << endl;
+    cout << "Total " << total_mins << endl;
+    cout << endl;
+    // */
+
+
+    //
+    // Calculate bounding boxes over each element
+    //
+    int numcellboxes = nel;
+    int cellboxdim = nsd * 2;
+    double *cellboxes = new double[numcellboxes * cellboxdim];
+    for (e = 0; e < nel; e++)
+    {
+        double minpt[3], maxpt[3];
+        double *bbox = &cellboxes[cellboxdim * e];
+
+        // get cell data
+        meshPart->get_cell_coords(e, cell->globverts);
+
+        cigma::minmax(cell->globverts, cell->n_nodes(), cell->n_dim(), minpt, maxpt);
+
+        // each box is stored as (min corner, max corner)
+        bbox[0] = minpt[0];
+        bbox[1] = minpt[1];
+        bbox[2] = minpt[2];
+
+        bbox[3] = maxpt[0];
+        bbox[4] = maxpt[1];
+        bbox[5] = maxpt[2];
+    }
+
+
+    //
+    // Search mesh using a spatial index
+    //
+
+    SpatialIndex *locator = new SpatialIndex();
+
+    // set_data() copies the boxes, so the scratch array can go right away
+    locator->set_data(cellboxes, numcellboxes, cellboxdim);
+    delete [] cellboxes;
+
+    int num_candidates = locator->k;
+    int *query_cells2 = new int[num_query_pts * num_candidates];
+    int *cellIndices = new int[num_candidates];
+
+    cout << "Using kdtree spatial index...(searching for " << num_candidates << " neighbors)" << endl;
+    cout << "point cell points/sec elapsed eta total progress" << endl;
+
+    time(&t0);
+    for (i = 0; i < num_query_pts; i++)
+    {
+        const bool report = (i % output_frequency == 0);
+        double *point = &query_points[3*i];
+
+        if (report)
+        {
+            cout << "(" << point[0] << " "
+                        << point[1] << " "
+                        << point[2] << ") ";
+        }
+
+        locator->find(point, cellIndices);
+
+        for (j = 0; j < num_candidates; j++)
+        {
+            query_cells2[num_candidates * i + j] = cellIndices[j];
+        }
+
+        if (report)
+        {
+            total_mins = report_progress(t0, i, num_query_pts, cellIndices[0]);
+        }
+    }
+
+    delete [] cellIndices;
+
+    cout << endl;
+    cout << "Total " << total_mins << endl;
+    cout << endl;
+
+
+    //
+    // Write out indices
+    //
+    int status = EXIT_SUCCESS;
+
+    ofstream indexfile;
+    indexfile.open("foo.index");
+    if (!indexfile.is_open())
+    {
+        cerr << "Error: could not open foo.index for writing" << endl;
+        status = EXIT_FAILURE;
+    }
+    else
+    {
+        for (i = 0; i < num_query_pts; i++)
+        {
+            double *point = &query_points[3*i];
+            const int expected = query_cells1[i];
+            const int *candidates = &query_cells2[num_candidates * i];
+
+            indexfile << setw(8) << point[0] << " "
+                      << setw(8) << point[1] << " "
+                      << setw(8) << point[2] << " "
+                      << setw(8) << expected << " ";
+
+            // list the candidates and note whether the brute force
+            // answer is among them
+            bool found = false;
+            for (j = 0; j < num_candidates; j++)
+            {
+                indexfile << setw(8) << candidates[j] << " ";
+                if (candidates[j] == expected)
+                {
+                    found = true;
+                }
+            }
+
+            indexfile << (found ? "yes" : "no!") << "\n";
+        }
+        indexfile.close();
+    }
+
+
+    //
+    // Clean up
+    //
+    delete locator;
+    delete [] query_points;
+    delete [] query_cells1;
+    delete [] query_cells2;
+
+
+    return status;
+}



More information about the cig-commits mailing list