diff --git a/src/org/openstreetmap/josm/actions/JoinNodeWayAction.java b/src/org/openstreetmap/josm/actions/JoinNodeWayAction.java
index 9453f6c..3f2623c 100644
--- a/src/org/openstreetmap/josm/actions/JoinNodeWayAction.java
+++ b/src/org/openstreetmap/josm/actions/JoinNodeWayAction.java
@@ -18,11 +18,13 @@ import org.openstreetmap.josm.tools.Shortcut;
 
 import java.awt.event.ActionEvent;
 import java.awt.event.KeyEvent;
+import java.util.Comparator;
 import java.util.Collection;
 import java.util.Collections;
 import java.util.LinkedList;
 import java.util.List;
 import java.util.Map;
+import java.util.HashMap;
 import java.util.Set;
 import java.util.SortedSet;
 import java.util.TreeSet;
@@ -65,58 +67,82 @@ public class JoinNodeWayAction extends JosmAction {
         if (!isEnabled())
             return;
         Collection<Node> selectedNodes = getCurrentDataSet().getSelectedNodes();
-        // Allow multiple selected nodes too?
-        if (selectedNodes.size() != 1) return;
-
-        final Node node = selectedNodes.iterator().next();
-
         Collection<Command> cmds = new LinkedList<>();
+        Map<Way, MultiMap<Integer, Node>> data = new HashMap<>();
 
         // If the user has selected some ways, only join the node to these.
         boolean restrictToSelectedWays =
                 !getCurrentDataSet().getSelectedWays().isEmpty();
 
-        List<WaySegment> wss = Main.map.mapView.getNearestWaySegments(
-                Main.map.mapView.getPoint(node), OsmPrimitive.isSelectablePredicate);
-        MultiMap<Way, Integer> insertPoints = new MultiMap<>();
-        for (WaySegment ws : wss) {
-            // Maybe cleaner to pass a "isSelected" predicate to getNearestWaySegments, but this is less invasive.
-            if (restrictToSelectedWays && !ws.way.isSelected()) {
-                continue;
-            }
+        // Planning phase: decide where we'll insert the nodes and put it all in "data"
+        for (Node node : selectedNodes) {
+            List<WaySegment> wss = Main.map.mapView.getNearestWaySegments(
+                    Main.map.mapView.getPoint(node), OsmPrimitive.isSelectablePredicate);
+
+            MultiMap<Way, Integer> insertPoints = new MultiMap<>();
+            for (WaySegment ws : wss) {
+                // Maybe cleaner to pass a "isSelected" predicate to getNearestWaySegments, but this is less invasive.
+                if (restrictToSelectedWays && !ws.way.isSelected()) {
+                    continue;
+                }
 
-            if (ws.getFirstNode() != node && ws.getSecondNode() != node) {
-                insertPoints.put(ws.way, ws.lowerIndex);
+                if (ws.getFirstNode() != node && ws.getSecondNode() != node) {
+                    insertPoints.put(ws.way, ws.lowerIndex);
+                }
+            }
+            for (Map.Entry<Way, Set<Integer>> entry : insertPoints.entrySet()) {
+                final Way w = entry.getKey();
+                final Set<Integer> insertPointsForWay = entry.getValue();
+                for (int i : pruneSuccs(insertPointsForWay)) {
+                    MultiMap<Integer, Node> innerMap;
+                    if (!data.containsKey(w)) {
+                        innerMap = new MultiMap<>();
+                    } else {
+                        innerMap = data.get(w);
+                    }
+                    innerMap.put(i, node);
+                    data.put(w, innerMap);
+                }
             }
         }
 
-        for (Map.Entry<Way, Set<Integer>> entry : insertPoints.entrySet()) {
+        // Execute phase: traverse the structure "data" and finally put the nodes into place
+        for (Map.Entry<Way, MultiMap<Integer, Node>> entry : data.entrySet()) {
             final Way w = entry.getKey();
-            final Set<Integer> insertPointsForWay = entry.getValue();
-            if (insertPointsForWay.isEmpty()) {
-                continue;
-            }
+            final MultiMap<Integer, Node> innerEntry = entry.getValue();
+
+            List<Integer> segmentIndexes = new LinkedList<Integer>();
+            segmentIndexes.addAll(innerEntry.keySet());
+            Collections.sort(segmentIndexes, Collections.reverseOrder());
 
-            List<Node> nodesToAdd = w.getNodes();
-            for (int i : pruneSuccsAndReverse(insertPointsForWay)) {
+            List<Node> wayNodes = w.getNodes();
+            for (Integer segmentIndex : segmentIndexes) {
+                final Set<Node> nodesInSegment = innerEntry.get(segmentIndex);
                 if (joinWayToNode) {
-                    EastNorth newPosition = Geometry.closestPointToSegment(
-                            w.getNode(i).getEastNorth(), w.getNode(i + 1).getEastNorth(), node.getEastNorth());
-                    cmds.add(new MoveCommand(node, Projections.inverseProject(newPosition)));
+                    for (Node node : nodesInSegment) {
+                        EastNorth newPosition = Geometry.closestPointToSegment(w.getNode(segmentIndex).getEastNorth(),
+                                                                            w.getNode(segmentIndex+1).getEastNorth(),
+                                                                            node.getEastNorth());
+                        cmds.add(new MoveCommand(node, Projections.inverseProject(newPosition)));
+                    }
                 }
-                nodesToAdd.add(i + 1, node);
+                List<Node> nodesToAdd = new LinkedList<Node>();
+                nodesToAdd.addAll(nodesInSegment);
+                Collections.sort(nodesToAdd, new nodeDistanceToRefNodeComparator(w.getNode(segmentIndex), w.getNode(segmentIndex+1), !joinWayToNode));
+                wayNodes.addAll(segmentIndex + 1, nodesToAdd);
             }
             Way wnew = new Way(w);
-            wnew.setNodes(nodesToAdd);
+            wnew.setNodes(wayNodes);
             cmds.add(new ChangeCommand(w, wnew));
         }
+
         if (cmds.isEmpty()) return;
         Main.main.undoRedo.add(new SequenceCommand(getValue(NAME).toString(), cmds));
         Main.map.repaint();
     }
 
-    private static SortedSet<Integer> pruneSuccsAndReverse(Collection<Integer> is) {
-        SortedSet<Integer> is2 = new TreeSet<>(Collections.reverseOrder());
+    private static SortedSet<Integer> pruneSuccs(Collection<Integer> is) {
+        SortedSet<Integer> is2 = new TreeSet<>();
         for (int i : is) {
             if (!is2.contains(i - 1) && !is2.contains(i + 1)) {
                 is2.add(i);
@@ -125,6 +151,39 @@ public class JoinNodeWayAction extends JosmAction {
         return is2;
     }
 
+    // Sorts collinear nodes by their distance to a common reference node.
+    private class nodeDistanceToRefNodeComparator implements Comparator<Node> {
+        private EastNorth refPoint;
+        private EastNorth refPoint2;
+        private boolean projectToSegment;
+        nodeDistanceToRefNodeComparator(Node referenceNode) {
+            refPoint = referenceNode.getEastNorth();
+            projectToSegment = false;
+        }
+        nodeDistanceToRefNodeComparator(Node referenceNode, Node referenceNode2, boolean projectFirst) {
+            refPoint = referenceNode.getEastNorth();
+            refPoint2 = referenceNode2.getEastNorth();
+            projectToSegment = projectFirst;
+        }
+        public int compare(Node first, Node second) {
+            EastNorth firstPosition = first.getEastNorth();
+            EastNorth secondPosition = second.getEastNorth();
+
+            if (projectToSegment) {
+                firstPosition = Geometry.closestPointToSegment(refPoint, refPoint2, firstPosition);
+                secondPosition = Geometry.closestPointToSegment(refPoint, refPoint2, secondPosition);
+            }
+
+            double distanceFirst = firstPosition.distance(refPoint);
+            double distanceSecond = secondPosition.distance(refPoint);
+            double difference =  distanceFirst - distanceSecond;
+
+            if (difference > 0.0) return 1;
+            if (difference < 0.0) return -1;
+            return 0;
+        }
+    }
+
     @Override
     protected void updateEnabledState() {
         if (getCurrentDataSet() == null) {
