From 9b2dabe48743fc19f4cf42d89fd249eca861de8c Mon Sep 17 00:00:00 2001
From: Romain Boman <r.boman@uliege.be>
Date: Mon, 19 Apr 2021 16:56:38 +0200
Subject: [PATCH] basic example of "sweep and prune"

---
 models/boneload.py | 437 ++++++++++++++++++++++++++++-----------------
 1 file changed, 274 insertions(+), 163 deletions(-)

diff --git a/models/boneload.py b/models/boneload.py
index 4de3fc5..e688ac6 100644
--- a/models/boneload.py
+++ b/models/boneload.py
@@ -5,19 +5,21 @@
 #   a model of muscle forces exerted on curvilinear bone structures
 #
 # from Ian R. Grosse, Elizabeth R. Dumont, Chris Coletta, Alex Tolleson
-#   "Techniques for Modeling Muscle-Induced Forces in Finite Element 
+#   "Techniques for Modeling Muscle-Induced Forces in Finite Element
 #   Models of Skeletal Structures"
 #   THE ANATOMICAL RECORD 290 pp.1069–1088
 #   https://doi.org/10.1002/ar.20568
 #
 # requires VTK and numpy
 
-import os, math
+import os
+import math
 import vtk
 colors = vtk.vtkNamedColors()
 # import numpy
 # from vtk.util.numpy_support import vtk_to_numpy
 
+
 def load_msh(meshfile):
     """loads a mesh in .off or .ply format
     returns: 
@@ -25,7 +27,7 @@ def load_msh(meshfile):
         a list of triangles defined by 3 node tags
     """
     ext = os.path.splitext(meshfile)[1].lower()
-    if ext=='.off':
+    if ext == '.off':
         return load_off(meshfile)
     elif ext in ['.ply', '.stl']:
         return load_msh_with_vtk(meshfile)
@@ -35,6 +37,8 @@ def load_msh(meshfile):
 
 def load_msh_with_vtk(meshfile):
     """loads a file in "Stanford University PLY" (.ply) or STL (.stl) using VTK
+    (this routine could be easily extended to any polydata format 
+    supported by VTK such as vtkOBJReader)
     returns: 
         a list of node coordinates 
         a list of triangles defined by 3 node tags
@@ -44,14 +48,14 @@ def load_msh_with_vtk(meshfile):
         raise Exception(f'{meshfile} not found')
 
     ext = os.path.splitext(meshfile)[1].lower()
-    if ext=='.ply':
+    if ext == '.ply':
         reader = vtk.vtkPLYReader()
-    elif ext=='.stl':
+    elif ext == '.stl':
         reader = vtk.vtkSTLReader()
     else:
         raise Exception(f'{meshfile}: unknown mesh format!')
     reader.SetFileName(meshfile)
-    reader.Update() 
+    reader.Update()
     polydata = reader.GetOutput()
 
     npts = polydata.GetNumberOfPoints()
@@ -68,24 +72,25 @@ def load_msh_with_vtk(meshfile):
     # data = polys.GetData()
     # print(f'data={vtk_to_numpy(data)}')
 
-    # see also https://stackoverflow.com/questions/51201888/retrieving-facets-and-point-from-vtk-file-in-python    
+    # see also https://stackoverflow.com/questions/51201888/retrieving-facets-and-point-from-vtk-file-in-python
     # (faster method using numpy)
     tris = []
     idList = vtk.vtkIdList()
     polys.InitTraversal()
     while polys.GetNextCell(idList):
-        ids=[]
+        ids = []
         for i in range(0, idList.GetNumberOfIds()):
             pId = idList.GetId(i)
             ids.append(pId)
         tris.append(ids)
-    
+
     print(f'\t{len(nodes)} nodes, {len(tris)} triangles read.')
     return nodes, tris
 
 
 def load_off(meshfile):
     """loads a file in "object file format" (.off)
+    (we do it by hand... no vtkOFFReader in VTK!)
     returns: 
         a list of node coordinates 
         a list of triangles defined by 3 node tags
@@ -102,21 +107,23 @@ def load_off(meshfile):
         line = infile.readline()  # reads header
         # second line - nb of nodes, nb of triangles
         line = infile.readline().rstrip()
-        nbno, nbtri, _ = [ int(data) for data in line.split(' ') ]
+        nbno, nbtri, _ = [int(data) for data in line.split(' ')]
         # reading nodes
         print(f'\treading {nbno} nodes...')
         for i in range(nbno):
             line = infile.readline().rstrip()
-            x, y, z, *remainder = [ float(data) for data in line.split(' ') ]
-            nodes.append([x,y,z])            
+            x, y, z, *remainder = [float(data) for data in line.split(' ')]
+            nodes.append([x, y, z])
         # reading triangles
         print(f'\treading {nbtri} triangles...')
         for i in range(nbtri):
             line = infile.readline().rstrip()
-            nno, n1, n2, n3, *remainder = [ int(data) for data in line.split(' ') ]
-            if nno!=3:
-                raise Exception(f"{os.path.basename(meshfile)} contains non-triangular elements")
-            tris.append([n1,n2,n3])
+            nno, n1, n2, n3, *remainder = [int(data)
+                                           for data in line.split(' ')]
+            if nno != 3:
+                raise Exception(
+                    f"{os.path.basename(meshfile)} contains non-triangular elements")
+            tris.append([n1, n2, n3])
 
     print(f'\t{len(nodes)} nodes, {len(tris)} triangles read.')
     return nodes, tris
@@ -126,24 +133,27 @@ def identify_nodes(coords, all_no, all_coords, eps=1e-3):
     """identify the nodes based on their coordinates "coords"
     among the list of coordinates "all_coords"
     all_no, given as argument, is an array of tags for the nodes defined in all_coords
-    
+
     The algorithm used is a naive brutal approach: O( len(all_coords)*len(coords) ).
     TODO: use a faster approach
         https://youtu.be/eED4bSkYCB8
         or using VTK?
-        
+
     returns ntags, a list of node tags of the identified nodes
     """
-    notfound=[]
-    ntags=[]
-    for n in coords: # for all nodes
+    notfound = []
+    ntags = []
+    for n in coords:  # for all nodes
         found = False
         # print(f'looking for node {n} among {len(all_no), len(all_coords)} nodes...')
         for no, pos in zip(all_no, all_coords):
             # print(f'testing node {n} == {pos}?')
-            if abs(n[0]-pos[0])>eps: continue  
-            if abs(n[1]-pos[1])>eps: continue  
-            if abs(n[2]-pos[2])>eps: continue  
+            if abs(n[0]-pos[0]) > eps:
+                continue
+            if abs(n[1]-pos[1]) > eps:
+                continue
+            if abs(n[2]-pos[2]) > eps:
+                continue
             # print(f'node {n} matches with node #{no} {pos}')
             found = True
             ntags.append(no)
@@ -151,7 +161,7 @@ def identify_nodes(coords, all_no, all_coords, eps=1e-3):
         if not found:
             notfound.append(n)
 
-    if(len(notfound)!=0):
+    if(len(notfound) != 0):
         print('WARNING: these nodes have not been identified:')
         for n in notfound:
             print(f'\tn={n}')
@@ -162,41 +172,54 @@ def identify_nodes(coords, all_no, all_coords, eps=1e-3):
 class Pt:
     """a 3D point (or vector) with the usual algebraic operations
     """
+
     def __init__(self, x):
         self.x = [float(v) for v in x]
+
     def __str__(self):
         s = ','.join([str(v) for v in self.x])
         return '(' + s + ')'
+
     def __add__(self, pt):
-        return Pt([ self.x[i]+pt.x[i] for i in range(3)])
+        return Pt([self.x[i]+pt.x[i] for i in range(3)])
+
     def __sub__(self, pt):
-        return Pt([ self.x[i]-pt.x[i] for i in range(3)])
+        return Pt([self.x[i]-pt.x[i] for i in range(3)])
+
     def __neg__(self):
-        return Pt([ -self.x[i] for i in range(3)])
+        return Pt([-self.x[i] for i in range(3)])
+
     def __truediv__(self, scalar):
-        return Pt([ self.x[i]/scalar for i in range(3)])
+        return Pt([self.x[i]/scalar for i in range(3)])
+
     def __mul__(self, obj):
         if isinstance(obj, Pt):
             return self.x[0]*obj.x[0]+self.x[1]*obj.x[1]+self.x[2]*obj.x[2]
         else:
-            return Pt([ self.x[i]*obj for i in range(3)])
+            return Pt([self.x[i]*obj for i in range(3)])
+
     def __rmul__(self, scalar):
         return self.__mul__(scalar)
+
     def __abs__(self):
         return math.sqrt(self*self)
+
     def normalized(self):
         return self / abs(self)
+
     def normalize(self):
         n = self.normalized()
         self.x = n.x
+
     def cross(self, pt):
-        return Pt([self.x[1]*pt.x[2] - self.x[2]*pt.x[1], \
-                  self.x[2]*pt.x[0] - self.x[0]*pt.x[2], \
-                  self.x[0]*pt.x[1] - self.x[1]*pt.x[0]])
+        return Pt([self.x[1]*pt.x[2] - self.x[2]*pt.x[1],
+                   self.x[2]*pt.x[0] - self.x[0]*pt.x[2],
+                   self.x[0]*pt.x[1] - self.x[1]*pt.x[0]])
+
     @staticmethod
     def test():
-        a = Pt([1,0,0])
-        b = Pt([0,1,0])
+        a = Pt([1, 0, 0])
+        b = Pt([0, 1, 0])
         c = a.cross(b)
         area = abs(c)/2
         print(f'area={area}')
@@ -206,7 +229,7 @@ def scale_loads(loads, total_force):
     """scale the loads so that their sum is total_force
     """
     # compute the total force
-    totf = Pt([0,0,0])
+    totf = Pt([0, 0, 0])
     for f in loads:
         totf = totf + f
     print(f'\t|force| = {abs(totf)} (before scaling)')
@@ -218,7 +241,7 @@ def scale_loads(loads, total_force):
         loads[i] *= scalingf
 
     # verification
-    totf = Pt([0,0,0])
+    totf = Pt([0, 0, 0])
     for f in loads:
         totf = totf + f
     print(f'\ttotal_force = {totf}')
@@ -238,23 +261,23 @@ def compute_normal_force_factor(muscle, centre, nunit, targetP, debug=False, sur
         print(f'centre={centre}')
         print(f'nunit={nunit}')
         plane = Plane(centre, centre+nunit, targetP)  # only for display
-    
+
     nclip = -(targetP - (centre+(targetP*nunit)*nunit)).normalized()
     planepart = ClipPolyData(muscle.polydata, centre, nclip)
-    planecut = PlaneCut(planepart.clip.GetOutput(), centre, \
-        nunit.cross(targetP-centre))
+    planecut = PlaneCut(planepart.clip.GetOutput(), centre,
+                        nunit.cross(targetP-centre))
     closest = ClosestPart(planecut.cutter.GetOutputPort(), centre)
     pts, ids = get_segments(closest.cfilter.GetOutput())
-    
+
     # compute "s(r)"
     s = compute_length(pts, ids)
-    if debug:  
+    if debug:
         print(f's(r)={s}')
 
     # front part of the cut
     planepart2 = ClipPolyData(muscle.polydata, centre, -nclip)
-    planecut2 = PlaneCut(planepart2.clip.GetOutput(), centre, \
-        nunit.cross(targetP-centre))
+    planecut2 = PlaneCut(planepart2.clip.GetOutput(), centre,
+                         nunit.cross(targetP-centre))
     closest2 = ClosestPart(planecut2.cutter.GetOutputPort(), centre)
     pts2, ids2 = get_segments(closest2.cfilter.GetOutput())
 
@@ -274,39 +297,41 @@ def compute_normal_force_factor(muscle, centre, nunit, targetP, debug=False, sur
         # 3D debugging view
         view = View()
 
-        view.addActors(surface.actor) 
-        view.addActors(muscle.actor)  
+        view.addActors(surface.actor)
+        view.addActors(muscle.actor)
         view.addActors(muscleF.actor)
-        view.addActors(plane.actor)  
-        view.addActors(planepart.actor) 
-        view.addActors(planecut.actor) 
-        view.addActors(closest.actor) 
-        view.addActors(closest2.actor) 
+        view.addActors(plane.actor)
+        view.addActors(planepart.actor)
+        view.addActors(planecut.actor)
+        view.addActors(closest.actor)
+        view.addActors(closest2.actor)
 
         surface.actor.GetProperty().SetOpacity(0.1)
 
         muscle.mapper.SetResolveCoincidentTopologyToPolygonOffset()
-        muscle.actor.GetProperty().SetColor( colors.GetColor3d('red') )
+        muscle.actor.GetProperty().SetColor(colors.GetColor3d('red'))
         muscle.actor.GetProperty().SetOpacity(0.5)
         muscle.actor.GetProperty().EdgeVisibilityOn()
         muscle.actor.GetProperty().SetPointSize(1.0)
 
         muscleF.actor.GetProperty().SetPointSize(15.0)
-        muscleF.actor.GetProperty().SetColor( colors.GetColor3d('green') )
+        muscleF.actor.GetProperty().SetColor(colors.GetColor3d('green'))
 
-        closest2.actor.GetProperty().SetColor( colors.GetColor3d('pink') )
+        closest2.actor.GetProperty().SetColor(colors.GetColor3d('pink'))
 
         view.interact()
 
         # 2D debugging view
 
-        # project vertices of fpath onto the plane for display        
+        # project vertices of fpath onto the plane for display
         origin, xaxis, yaxis = build_local_axes(fpath, nunit)
-        xpath = [ (p-origin)*xaxis for p in reversed(fpath2) ] + [ (p-origin)*xaxis for p in fpath ]
-        ypath = [ (p-origin)*yaxis for p in reversed(fpath2) ] + [ (p-origin)*yaxis for p in fpath ]
+        xpath = [(p-origin)*xaxis for p in reversed(fpath2)] + \
+            [(p-origin)*xaxis for p in fpath]
+        ypath = [(p-origin)*yaxis for p in reversed(fpath2)] + \
+            [(p-origin)*yaxis for p in fpath]
 
         import matplotlib.pyplot as plt
-        import numpy as np        
+        import numpy as np
         plt.plot(xpath, ypath, 'o-', label='path')
         plt.plot(polyx, polyy, 'r-', label='polynomial')
 
@@ -325,7 +350,7 @@ def create_loads(nodes, tris, total_force, target, method='T'):
         (see Grosse et al. 2007, Techniques for Modeling Muscle-
         Induced Forces in Finite Element Models of Skeletal Structures
         https://doi.org/10.1002/ar.20568 )
-    
+
     method = 'U' (Ad Hoc Uniform Traction Model)
         "In this model, we applied uniform traction to the surfaces
         of finite elements that represent the muscle attachment
@@ -349,67 +374,68 @@ def create_loads(nodes, tris, total_force, target, method='T'):
     """
 
     # options
-    if method=='U':
+    if method == 'U':
         project = False
-        methodtxt='Ad Hoc Uniform Traction Model'
+        methodtxt = 'Ad Hoc Uniform Traction Model'
         normal_comp = None
-    elif method=='T':
-        project = True 
+    elif method == 'T':
+        project = True
         normal_comp = False
-        methodtxt='Tangential-Traction Model'
-    elif method=='T+N':
-        project = True 
+        methodtxt = 'Tangential-Traction Model'
+    elif method == 'T+N':
+        project = True
         normal_comp = True
-        methodtxt='Tangential-Plus-Normal-Traction Model'
+        methodtxt = 'Tangential-Plus-Normal-Traction Model'
     else:
-        raise Exception("unknown method: choose from 'uniform', 'tangential' or 'tangential+normal'")
-
+        raise Exception(
+            "unknown method: choose from 'uniform', 'tangential' or 'tangential+normal'")
 
     if normal_comp:
         muscle = SurfMesh(nodes, tris, vertices=False)
 
-    loads = [Pt([0,0,0]) for n in nodes ]
+    loads = [Pt([0, 0, 0]) for n in nodes]
 
     targetP = Pt(target)
 
     print(f'distributing the total force on triangles ({methodtxt})...')
-    total_area = 0.  
+    total_area = 0.
     for tri in tris:
-        p1, p2, p3 = [ Pt(nodes[tri[i]]) for i in range(3) ]
+        p1, p2, p3 = [Pt(nodes[tri[i]]) for i in range(3)]
         centre = (p1+p2+p3)/3       # barycentre
         e1 = p2-p1                  # edge 1
         e2 = p3-p1                  # edge 2
-        direct = targetP-centre 
+        direct = targetP-centre
         direct.normalize()          # direction of the traction (normalised)
         n = e1.cross(e2)            # normal vector
         area2 = abs(n)              # 2 x area
         area = area2/2              # triangle area
         total_area += area
         nunit = n/area2             # unit normal vector
-        
-        # we use a unit traction which will be scaled later, so that the total 
+
+        # we use a unit traction which will be scaled later, so that the total
         #   force is correctly prescibed
         traction = direct
         if project:
             ps = traction*nunit
-            if ps<0.0: # no line of sight => projection onto the element plane
+            if ps < 0.0:  # no line of sight => projection onto the element plane
                 traction = traction - ps*nunit
                 # tangential traction should be normalised if we want to follow
                 #   Grosse implementation
-                traction.normalize() # could fail if direct and nunit are aligned in opposite directions
+                traction.normalize()  # could fail if direct and nunit are aligned in opposite directions
 
                 # normal component (requires curvature and fiber length calculations)
                 if normal_comp:
-                    s, curv = compute_normal_force_factor(muscle, centre, nunit, targetP)
-                    if curv<0.0:
+                    s, curv = compute_normal_force_factor(
+                        muscle, centre, nunit, targetP)
+                    if curv < 0.0:
                         traction = traction + (s*curv)*nunit
 
         force = (area/3)*traction      # compute nodal force
 
         for i in range(3):          # 3 nodes per triangle
-            loads[tri[i]]+=force    # assembly
+            loads[tri[i]] += force    # assembly
 
-    print(f'\ttotal_area = {total_area}') # (OK: verified with meshlab)
+    print(f'\ttotal_area = {total_area}')  # (OK: verified with meshlab)
     print(f'\ttotal_force = {total_force}')
 
     loads = scale_loads(loads, total_force)
@@ -417,24 +443,24 @@ def create_loads(nodes, tris, total_force, target, method='T'):
     return loads
 
 
-
 class Plane:
     """construct a plane mesh and create associated mapper & actor
     (used to see the plane orientation - not used for calculations) 
     """
+
     def __init__(self, c, p1, p2):
         source = vtk.vtkPlaneSource()
         source.SetOrigin(c.x[0], c.x[1], c.x[2])
         source.SetPoint1(p1.x[0], p1.x[1], p1.x[2])
         source.SetPoint2(p2.x[0], p2.x[1], p2.x[2])
-        
+
         mapper = vtk.vtkPolyDataMapper()
-        mapper.SetInputConnection( source.GetOutputPort())
+        mapper.SetInputConnection(source.GetOutputPort())
 
         actor = vtk.vtkActor()
-        actor.SetMapper(mapper)        
-        actor.GetProperty().SetColor( colors.GetColor3d('green') )
-        actor.GetProperty().SetOpacity(0.5)        
+        actor.SetMapper(mapper)
+        actor.GetProperty().SetColor(colors.GetColor3d('green'))
+        actor.GetProperty().SetOpacity(0.5)
 
         self.source = source
         self.mapper = mapper
@@ -445,10 +471,11 @@ class ClipPolyData:
     """cut polydata with a plane in 2 sets of parts and keep 1 set
     input: point and normal vector defining the plane
     """
+
     def __init__(self, polydata, point, normal):
         plane = vtk.vtkPlane()
-        plane.SetOrigin(point.x[0],point.x[1],point.x[2] )
-        plane.SetNormal(normal.x[0],normal.x[1],normal.x[2])
+        plane.SetOrigin(point.x[0], point.x[1], point.x[2])
+        plane.SetNormal(normal.x[0], normal.x[1], normal.x[2])
 
         clip = vtk.vtkClipPolyData()
         clip.SetClipFunction(plane)
@@ -456,11 +483,11 @@ class ClipPolyData:
         clip.Update()
 
         mapper = vtk.vtkPolyDataMapper()
-        mapper.SetInputConnection( clip.GetOutputPort() )
+        mapper.SetInputConnection(clip.GetOutputPort())
 
         actor = vtk.vtkActor()
-        actor.SetMapper(mapper)        
-        actor.GetProperty().SetColor( colors.GetColor3d('orange') )
+        actor.SetMapper(mapper)
+        actor.GetProperty().SetColor(colors.GetColor3d('orange'))
 
         self.plane = plane
         self.clip = clip
@@ -472,6 +499,7 @@ class PlaneCut:
     """cut a polydata with a plane and keep the intersection line
     input: point and normal vector defining the plane
     """
+
     def __init__(self, polydata, point, normal):
 
         plane = vtk.vtkPlane()
@@ -484,11 +512,11 @@ class PlaneCut:
         cutter.Update()
 
         mapper = vtk.vtkPolyDataMapper()
-        mapper.SetInputConnection( cutter.GetOutputPort())
+        mapper.SetInputConnection(cutter.GetOutputPort())
 
         actor = vtk.vtkActor()
-        actor.SetMapper(mapper)        
-        actor.GetProperty().SetColor( colors.GetColor3d('green') )
+        actor.SetMapper(mapper)
+        actor.GetProperty().SetColor(colors.GetColor3d('green'))
         actor.GetProperty().SetLineWidth(2)
 
         self.plane = plane
@@ -500,19 +528,20 @@ class PlaneCut:
 class ClosestPart:
     """keep the part of a polydata which is closest to a given point
     """
+
     def __init__(self, outputport, point):
-        cfilter = vtk.vtkPolyDataConnectivityFilter() 
-        cfilter.SetExtractionModeToClosestPointRegion()	
-        cfilter.SetClosestPoint(point.x[0], point.x[1], point.x[2]) 
+        cfilter = vtk.vtkPolyDataConnectivityFilter()
+        cfilter.SetExtractionModeToClosestPointRegion()
+        cfilter.SetClosestPoint(point.x[0], point.x[1], point.x[2])
         cfilter.SetInputConnection(outputport)
         cfilter.Update()
 
         mapper = vtk.vtkPolyDataMapper()
-        mapper.SetInputConnection( cfilter.GetOutputPort())
+        mapper.SetInputConnection(cfilter.GetOutputPort())
 
         actor = vtk.vtkActor()
-        actor.SetMapper(mapper)        
-        actor.GetProperty().SetColor( colors.GetColor3d('yellow') )
+        actor.SetMapper(mapper)
+        actor.GetProperty().SetColor(colors.GetColor3d('yellow'))
         actor.GetProperty().SetLineWidth(4)
 
         self.cfilter = cfilter
@@ -539,10 +568,10 @@ def get_segments(poly):
     lines.InitTraversal()
     while lines.GetNextCell(idList):
         npts = idList.GetNumberOfIds()
-        if npts!=2:
+        if npts != 2:
             print("intersection contains bad cells")
             continue
-        ids.append(( idList.GetId(0),idList.GetId(1) ))
+        ids.append((idList.GetId(0), idList.GetId(1)))
     return pts, ids
 
 
@@ -551,10 +580,10 @@ def compute_length(pts, ids):
     input: pts: a list of points coordinates (Pt objects)
            ids: a list of tuples of indexing the 2 vertices of each segment
     """
-    length=0.
+    length = 0.
     for v in ids:
         dl = abs(pts[v[1]]-pts[v[0]])
-        length+=dl
+        length += dl
         # print(f'\tdl={dl}')
     return length
 
@@ -569,32 +598,32 @@ def sort_segments(pts, ids):
          the segments are connected so that they make a open piecewise-linear curve.
     """
     # build a list of segments for each vertex
-    segs_by_vertex = [ [] for p in pts ]
+    segs_by_vertex = [[] for p in pts]
     for v in ids:
-        segs_by_vertex[ v[0] ].append( v )
-        segs_by_vertex[ v[1] ].append( v )
-    # print(f'segs_by_vertex={segs_by_vertex}')    
+        segs_by_vertex[v[0]].append(v)
+        segs_by_vertex[v[1]].append(v)
+    # print(f'segs_by_vertex={segs_by_vertex}')
 
     # look for first segment
     #   this is the first segment which is connected to a single vertex
     #   curp = current vertex
     #   segs = list of segments connected to vertex #curp
     for curp, segs in enumerate(segs_by_vertex):
-        if len(segs)==1: # the vertex only has 1 connected segment
-            segment = segs[0] # this is the segment we are looking for
-            # the next point is the other vertex of the segment found   
+        if len(segs) == 1:  # the vertex only has 1 connected segment
+            segment = segs[0]  # this is the segment we are looking for
+            # the next point is the other vertex of the segment found
             nextp = segment[0]
-            if nextp==curp:
-                nextp=segment[1]
+            if nextp == curp:
+                nextp = segment[1]
             break
 
     # put the first segment is the sorted list
     #   reverse it if vertex curp is not the first one
     sorted_ids = []
     if segment[0] == curp:
-        sorted_ids.append( segment )
+        sorted_ids.append(segment)
     else:
-        sorted_ids.append( (segment[1], segment[0]) )
+        sorted_ids.append((segment[1], segment[0]))
 
     # print(f'segment={segment}')
     # print(f'nextp={nextp}')
@@ -602,14 +631,15 @@ def sort_segments(pts, ids):
     count = 0
     while True:
         curp = nextp                  # set current vertex
-        segs = segs_by_vertex[curp]   # retrieve the edges attached to this vertex
-        
-        # if there is only one segment attached to vertex #curp, 
+        # retrieve the edges attached to this vertex
+        segs = segs_by_vertex[curp]
+
+        # if there is only one segment attached to vertex #curp,
         #   the last one was the last one. There is no segment to add
-        if len(segs)==1:
+        if len(segs) == 1:
             break
         # otherwise, there should be 2 segments
-        assert(len(segs)==2)
+        assert(len(segs) == 2)
 
         # set "segment" to the segment different from the last one
         lastseg = segment   # keep track of the previous segment
@@ -618,20 +648,20 @@ def sort_segments(pts, ids):
             segment = segs[1]
         # add segment to the sorted list (and reverse it if necessary)
         if segment[0] == curp:
-            sorted_ids.append( segment )
+            sorted_ids.append(segment)
         else:
-            sorted_ids.append( (segment[1], segment[0]) )
+            sorted_ids.append((segment[1], segment[0]))
 
-        # the next point is the other vertex of the segment found 
+        # the next point is the other vertex of the segment found
         nextp = segment[0]
-        if nextp==curp:
-            nextp=segment[1]    
+        if nextp == curp:
+            nextp = segment[1]
 
-        count+=1
-        if count==len(pts):
+        count += 1
+        if count == len(pts):
             raise Exception('sort_segments: bad path (1)')
 
-    if len(sorted_ids)!=len(ids):
+    if len(sorted_ids) != len(ids):
         raise Exception('sort_segments: bad path (2)')
 
     return sorted_ids
@@ -649,13 +679,14 @@ def sort_vertices(pts, ids, centre):
 
     # check whether "centre" is the last vertex
     #   reverse the list if it is the case
-    if( abs(centre-pts[sorted_ids[-1][0]]) < abs(centre-pts[sorted_ids[0][0]]) ):
-        sorted_ids = [ (seg[1], seg[0]) for seg in reversed(sorted_ids)]
+    if(abs(centre-pts[sorted_ids[-1][0]]) < abs(centre-pts[sorted_ids[0][0]])):
+        sorted_ids = [(seg[1], seg[0]) for seg in reversed(sorted_ids)]
     # print(f'sorted_ids={sorted_ids}')
 
     # build sorted vertex indices
-    sorted_v = [ seg[0] for seg in sorted_ids ] # append first vertex of each segment
-    sorted_v.append( sorted_ids[-1][1] ) # append last vertex
+    # append first vertex of each segment
+    sorted_v = [seg[0] for seg in sorted_ids]
+    sorted_v.append(sorted_ids[-1][1])  # append last vertex
     # print(f'sorted_v={sorted_v}')
 
     return sorted_v
@@ -668,9 +699,9 @@ def build_clean_path(pts, sorted_v):
     # fill a list with all the (sorted) current coordinates
     fpath = []
     for iv in sorted_v:
-        fpath.append( pts[ iv ] )
+        fpath.append(pts[iv])
 
-    # Compute the longest segment length 
+    # Compute the longest segment length
     #   (used for adimensional tests later)
     dlmax = 0.0
     for i in range(len(fpath)-1):
@@ -684,7 +715,7 @@ def build_clean_path(pts, sorted_v):
     #   we delete it and start the loop again...
     #   This method is not the fastest but it works.
     while True:
-        to_remove=None
+        to_remove = None
         for i in range(1, len(fpath)-1):
             p0 = fpath[i-1]
             p1 = fpath[i]
@@ -694,15 +725,15 @@ def build_clean_path(pts, sorted_v):
             d1u = d1.normalized()
             d2u = d2.normalized()
             # print(f'd1u.cross(d2u)={abs(d1u.cross(d2u))}')
-            if(abs(d1u.cross(d2u))<1.e-2): # colinear
+            if(abs(d1u.cross(d2u)) < 1.e-2):  # colinear
                 to_remove = i
-                break                    
-            if abs(d1)<1.e-3*dlmax or abs(d1)<1.e-3*dlmax: # too close
+                break
+            if abs(d1) < 1.e-3*dlmax or abs(d1) < 1.e-3*dlmax:  # too close
                 to_remove = i
                 break
         # check whether a vertex should be removed or not
-        if to_remove!=None:
-            # print(f'removing vertex at {fpath[to_remove]}') 
+        if to_remove != None:
+            # print(f'removing vertex at {fpath[to_remove]}')
             del fpath[to_remove]
         else:
             break   # the path is clean
@@ -715,8 +746,8 @@ def build_local_axes(fpath, nunit):
     origin = fpath[0]
     p1 = fpath[1]     # next point
     xaxis = (p1-origin).normalized()
-    yaxis = nunit 
-    return origin, xaxis, yaxis   
+    yaxis = nunit
+    return origin, xaxis, yaxis
 
 
 def compute_curvature(fpath, fpath2, nunit):
@@ -728,7 +759,7 @@ def compute_curvature(fpath, fpath2, nunit):
     built from the position of the 2 first vertices of the 2 given paths
     """
     # build a list of points for the fitting procedure
-    #   we take point #1 and #2 of the 2 paths (point#0 is the shared vertex) 
+    #   we take point #1 and #2 of the 2 paths (point#0 is the shared vertex)
     #   if less points are found, the order of approximation is reduced
     fitpts = fpath2[2:0:-1] + fpath[1:3]
     # order of the polynomial approximation
@@ -738,12 +769,12 @@ def compute_curvature(fpath, fpath2, nunit):
     origin, xaxis, yaxis = build_local_axes(fpath, nunit)
 
     # use numpy.polyfit
-    import numpy as np  
-    x = [ (p-origin)*xaxis for p in fitpts ]
-    y = [ (p-origin)*yaxis for p in fitpts ]
+    import numpy as np
+    x = [(p-origin)*xaxis for p in fitpts]
+    y = [(p-origin)*yaxis for p in fitpts]
     z = np.polyfit(x, y, order)
     polynom = np.poly1d(z)
-    polyx = np.linspace(x[0],x[-1],100)  # only for display
+    polyx = np.linspace(x[0], x[-1], 100)  # only for display
     polyy = polynom(polyx)               # only for display
     # compute the curvature using first and second derivatives as in
     #   https://www.math24.net/curvature-radius
@@ -769,12 +800,12 @@ class SurfMesh:
         """create a VTK polydata mesh from nodes and triangles
         """
         polydata = vtk.vtkPolyData()
-        points   = vtk.vtkPoints()
-        polys    = vtk.vtkCellArray()
-        verts    = vtk.vtkCellArray()
+        points = vtk.vtkPoints()
+        polys = vtk.vtkCellArray()
+        verts = vtk.vtkCellArray()
 
         # print(f'len(nodes) = {len(nodes)}')
-        for i,x in enumerate(nodes):
+        for i, x in enumerate(nodes):
             points.InsertPoint(i, x[0], x[1], x[2])
             verts.InsertNextCell(1)
             verts.InsertCellPoint(i)
@@ -815,8 +846,8 @@ class SurfMesh:
         # manually compute vmin,vmax
         # vmin=+1.0e10
         # vmax=-1.0e10
-        for i,v in enumerate(vectors):
-            darray.InsertTuple3(i, v.x[0], v.x[1], v.x[2] )
+        for i, v in enumerate(vectors):
+            darray.InsertTuple3(i, v.x[0], v.x[1], v.x[2])
             n = abs(v)
             # if n<vmin: vmin=n
             # if n>vmax: vmax=n
@@ -825,13 +856,13 @@ class SurfMesh:
         # print(f'vmin, vmax={(vmin,vmax)}')
 
         self.polydata.GetPointData().AddArray(darray)
-        self.polydata.GetPointData().SetActiveVectors("forces")       
+        self.polydata.GetPointData().SetActiveVectors("forces")
 
     @staticmethod
     def load(off_file, vertices=True):
         fullpath = os.path.join(os.path.dirname(__file__), off_file)
         nodes, tris = load_msh(fullpath)
-        mesh = SurfMesh(nodes, tris, vertices)   
+        mesh = SurfMesh(nodes, tris, vertices)
         return mesh, nodes, tris
 
 
@@ -859,7 +890,7 @@ class Arrows:
         mapper.SetInputConnection(glyph.GetOutputPort())
         mapper.ScalarVisibilityOn()
 
-        vmin, vmax = polydata.GetPointData().GetVectors().GetRange(-1) # -1=norm
+        vmin, vmax = polydata.GetPointData().GetVectors().GetRange(-1)  # -1=norm
         mapper.SetScalarRange(vmin, vmax)
 
         actor = vtk.vtkActor()
@@ -869,7 +900,6 @@ class Arrows:
         self.glyph = glyph
         self.mapper = mapper
         self.actor = actor
-        
 
 
 # ------------------------------------------------------------------------------
@@ -877,7 +907,7 @@ class Arrows:
 class View:
     def __init__(self):
         self.ren = vtk.vtkRenderer()
-        self.ren.SetBackground(colors.GetColor3d('cobalt') )
+        self.ren.SetBackground(colors.GetColor3d('cobalt'))
 
         self.win = vtk.vtkRenderWindow()
         self.win.SetSize(800, 800)
@@ -898,13 +928,14 @@ class View:
             self.ren.AddActor(a)
 
     def interact(self):
-        self.ren.ResetCamera()        
+        self.ren.ResetCamera()
         self.intor.Initialize()
         self.win.Render()
         self.intor.Start()
 
 # ------------------------------------------------------------------------------
 
+
 class ParaviewAxes:
     "axes a la paraview"
 
@@ -917,8 +948,8 @@ class ParaviewAxes:
         axes.SetTotalLength(1, 1, 1)
         tprop = vtk.vtkTextProperty()
         tprop.ItalicOn()
-        #tprop.ShadowOn()
-        #tprop.SetFontFamilyToTimes()
+        # tprop.ShadowOn()
+        # tprop.SetFontFamilyToTimes()
         axes.GetXAxisCaptionActor2D().SetCaptionTextProperty(tprop)
         axes.GetYAxisCaptionActor2D().SetCaptionTextProperty(tprop)
         axes.GetZAxisCaptionActor2D().SetCaptionTextProperty(tprop)
@@ -935,9 +966,89 @@ class ParaviewAxes:
 # ------------------------------------------------------------------------------
 
 
-if __name__=="__main__":
+if __name__ == "__main__":
 
     Pt.test()
     # load_ply('Lmuscle.ply')
     # print(help(vtk.vtkCellArray))
     # import sys; sys.exit()
+
+    # basic implementation of "sweep & prune" algorithm.
+
+    # build a list of points with random coordinates
+    import random
+    all_pts = []
+    for i in range(10):
+        all_pts.append(Pt([random.random(), random.random(), random.random()]))
+
+    # build a list of selected points
+    pts = []
+    pts.append(all_pts[1])
+    pts.append(all_pts[2])
+    pts.append(all_pts[5])
+    pts.append(all_pts[8])
+
+    # add some variants of the selected pts
+    for p in pts[0:3]:
+        all_pts.append(Pt([p.x[0], 0.0, 1.0]))
+        all_pts.append(Pt([p.x[0], p.x[1], 1.0]))
+
+    print(f'all points:')
+    for no, p in enumerate(all_pts):
+        print(f'\t{no}: {p}')
+
+    print(f'selected:')
+    for p in pts:
+        print(f'\t{p}')
+
+    class Nod:
+        def __init__(self, x, idx=-1):  # selected points have  a -1 index
+            self.x = x
+            self.idx = idx
+
+        def __str__(self):
+            return f'{self.idx}: {self.x}'
+
+    # build node list
+    nods = []
+    for no, p in enumerate(all_pts):
+        nods.append(Nod(p.x, no))
+    for p in pts:
+        nods.append(Nod(p.x, -1))
+
+    # sort nodes according to their x coordinate
+    nods.sort(key=lambda nod: nod.x[0])
+
+    print('sorted list:')
+    for n in nods:
+        print(f'\t{n}')
+
+    print('sweep and prune:')
+    eps = 1e-3    # geometrical tolerance
+    ntests = 0
+    nnods = len(nods)
+    i = 0
+    while i < nnods:
+        n = nods[i]
+        x = n.x[0]
+        i2 = i
+        while True:
+            i2 += 1
+            if i2 == nnods:
+                break
+            n2 = nods[i2]
+            x2 = n2.x[0]
+            if abs(x2-x) > eps:  # test X
+                break
+            ntests += 1
+            if n.idx >= 0 and n2.idx >= 0:  # nodes from different sets
+                continue
+            if abs(n.x[1]-n2.x[1]) > eps:  # test Y
+                continue
+            if abs(n.x[2]-n2.x[2]) > eps:  # test Z
+                continue
+            no = max(n.idx, n2.idx)
+            print(f'{no} is selected')
+        i += 1
+    print(
+        f'sweep and prune done ({ntests} tests instead of {len(pts)*len(all_pts)})')
-- 
GitLab