diff --git a/app/src/main/java/network/grape/service/GrapeVpnService.java b/app/src/main/java/network/grape/service/GrapeVpnService.java index 9c45da30..c95b6a2c 100644 --- a/app/src/main/java/network/grape/service/GrapeVpnService.java +++ b/app/src/main/java/network/grape/service/GrapeVpnService.java @@ -11,6 +11,10 @@ import java.net.Socket; import java.net.UnknownHostException; import java.nio.ByteBuffer; +import java.util.concurrent.BlockingQueue; +import java.util.concurrent.LinkedBlockingQueue; +import java.util.concurrent.ThreadPoolExecutor; +import java.util.concurrent.TimeUnit; import network.grape.lib.PacketHeaderException; import org.slf4j.Logger; import org.slf4j.LoggerFactory; @@ -132,7 +136,9 @@ void startTrafficHandler() throws IOException { SessionHandler handler = new SessionHandler(sessionManager, new SocketProtector(this)); // background thread for writing output to the vpn outputstream - vpnWriter = new VpnWriter(clientWriter, sessionManager); + final BlockingQueue taskQueue = new LinkedBlockingQueue<>(); + ThreadPoolExecutor executor = new ThreadPoolExecutor(8, 100, 10, TimeUnit.SECONDS, taskQueue); + vpnWriter = new VpnWriter(clientWriter, sessionManager, executor); vpnWriterThread = new Thread(vpnWriter); vpnWriterThread.start(); diff --git a/app/src/main/java/network/grape/service/VpnWriter.java b/app/src/main/java/network/grape/service/VpnWriter.java index 845db944..882ce469 100644 --- a/app/src/main/java/network/grape/service/VpnWriter.java +++ b/app/src/main/java/network/grape/service/VpnWriter.java @@ -36,19 +36,6 @@ public class VpnWriter implements Runnable { private ThreadPoolExecutor workerPool; private volatile boolean running; - /** - * Construct a new VpnWriter. - * - * @param outputStream the stream to write back into the VPN interface. - */ - public VpnWriter(FileOutputStream outputStream, SessionManager sessionManager) { - this.logger = LoggerFactory.getLogger(VpnWriter.class); - this.outputStream = outputStream; - this.sessionManager = sessionManager; - final BlockingQueue taskQueue = new LinkedBlockingQueue<>(); - workerPool = new ThreadPoolExecutor(8, 100, 10, TimeUnit.SECONDS, taskQueue); - } - /** * Construct a new VpnWriter with the workerpool provided. * @@ -63,6 +50,14 @@ public VpnWriter(FileOutputStream outputStream, SessionManager sessionManager, this.workerPool = workerPool; } + boolean isRunning() { + return running; + } + + boolean notRunning() { + return !running; + } + /** * Main thread for the VpnWriter. */ @@ -70,7 +65,7 @@ public void run() { logger.info("VpnWriter starting in the background"); selector = sessionManager.getSelector(); running = true; - while (running) { + while (isRunning()) { // first just try to wait for a socket to be ready for a connect, read, etc try { @@ -87,7 +82,7 @@ public void run() { continue; } - if (!running) { + if (notRunning()) { break; } @@ -103,7 +98,7 @@ public void run() { processUdpSelectionKey(key); } iterator.remove(); - if (!running) { + if (notRunning()) { break; } } diff --git a/app/src/test/java/network/grape/service/VpnWriterTest.java b/app/src/test/java/network/grape/service/VpnWriterTest.java index 1a0010ae..d68fba6d 100644 --- a/app/src/test/java/network/grape/service/VpnWriterTest.java +++ b/app/src/test/java/network/grape/service/VpnWriterTest.java @@ -10,12 +10,20 @@ import static org.mockito.Mockito.verify; import static org.mockito.Mockito.when; + import java.io.FileOutputStream; import java.io.IOException; import java.net.DatagramSocket; import java.net.InetAddress; import java.nio.channels.DatagramChannel; +import java.nio.channels.FileChannel; import java.nio.channels.SelectionKey; +import java.nio.channels.Selector; +import java.nio.channels.ServerSocketChannel; +import java.nio.channels.SocketChannel; +import java.util.HashSet; +import java.util.Iterator; +import java.util.Set; import java.util.concurrent.ThreadPoolExecutor; import org.junit.jupiter.api.BeforeEach; import org.junit.jupiter.api.Disabled; @@ -127,12 +135,9 @@ public void sessionConnected() throws IOException { @Test public void testProcessSelector() { - FileOutputStream fileOutputStream = mock(FileOutputStream.class); Session session = mock(Session.class); SelectionKey selectionKey = mock(SelectionKey.class); - ThreadPoolExecutor workerPool = mock(ThreadPoolExecutor.class); - SessionManager sessionManager = mock(SessionManager.class); - VpnWriter vpnWriter = new VpnWriter(fileOutputStream, sessionManager, workerPool); + VpnWriter vpnWriter = spy(new VpnWriter(fileOutputStream, sessionManager, workerPool)); when(selectionKey.isValid()).thenReturn(false); vpnWriter.processSelector(selectionKey, session); @@ -185,4 +190,107 @@ public void testProcessSelector() { vpnWriter.processSelector(selectionKey, session); verify(session, Mockito.times(1)).setBusyWrite(true); } + + @Test + public void runTest() throws InterruptedException, IOException { + + // base case, nothing back from the selector + VpnWriter vpnWriter = spy(new VpnWriter(fileOutputStream, sessionManager, workerPool)); + Selector selector = mock(Selector.class); + when(sessionManager.getSelector()).thenReturn(selector); + when(vpnWriter.isRunning()).thenReturn(true).thenReturn(false); + Thread t = new Thread(vpnWriter); + t.start(); + t.join(); + + // exception on select + vpnWriter = spy(new VpnWriter(fileOutputStream, sessionManager, workerPool)); + selector = mock(Selector.class); + when(selector.select()).thenThrow(IOException.class); + when(sessionManager.getSelector()).thenReturn(selector); + when(vpnWriter.isRunning()).thenReturn(true).thenReturn(false); + t = new Thread(vpnWriter); + t.start(); + t.join(); + + // exception on select + interrupt in handler + vpnWriter = spy(new VpnWriter(fileOutputStream, sessionManager, workerPool)); + selector = mock(Selector.class); + when(selector.select()).thenThrow(IOException.class); + when(sessionManager.getSelector()).thenReturn(selector); + when(vpnWriter.isRunning()).thenReturn(true).thenReturn(false); + t = new Thread(vpnWriter); + t.start(); + // there is a chance thread scheduling will be bad and the interrupted exception be thrown + // in time here, but its okay. + Thread.sleep(100); + t.interrupt(); + t.join(); + } + + @Test public void runTestSelectionSet() throws InterruptedException { + // non-empty iterator + Set selectionKeySet = new HashSet<>(); + SelectionKey udpKey = mock(SelectionKey.class); + DatagramChannel udpChannel = mock(DatagramChannel.class); + when(udpKey.channel()).thenReturn(udpChannel); + selectionKeySet.add(udpKey); + + SelectionKey tcpKey = mock(SelectionKey.class); + SocketChannel tcpChannel = mock(SocketChannel.class); + when(tcpKey.channel()).thenReturn(tcpChannel); + selectionKeySet.add(tcpKey); + + SelectionKey serverSocketKey = mock(SelectionKey.class); + ServerSocketChannel serverSocketChannel = mock(ServerSocketChannel.class); + when(serverSocketKey.channel()).thenReturn(serverSocketChannel); + selectionKeySet.add(serverSocketKey); + + VpnWriter vpnWriter = spy(new VpnWriter(fileOutputStream, sessionManager, workerPool)); + doNothing().when(vpnWriter).processUdpSelectionKey(any()); + + Selector selector = mock(Selector.class); + when(sessionManager.getSelector()).thenReturn(selector); + when(selector.selectedKeys()).thenReturn(selectionKeySet); + Thread t = new Thread(vpnWriter); + t.start(); + when(vpnWriter.isRunning()).thenReturn(true).thenReturn(false); + when(vpnWriter.notRunning()).thenReturn(false); + vpnWriter.shutdown(); + t.join(); + } + + @Test public void runTestNotRunning() throws InterruptedException { + // base case, nothing back from the selector + VpnWriter vpnWriter = spy(new VpnWriter(fileOutputStream, sessionManager, workerPool)); + Selector selector = mock(Selector.class); + when(sessionManager.getSelector()).thenReturn(selector); + when(vpnWriter.isRunning()).thenReturn(true).thenReturn(false); + when(vpnWriter.notRunning()).thenReturn(true); + Thread t = new Thread(vpnWriter); + t.start(); + t.join(); + } + + @Test public void runTestNotRunningNonEmptyIterator() throws InterruptedException { + // non-empty iterator + Set selectionKeySet = new HashSet<>(); + SelectionKey udpKey = mock(SelectionKey.class); + DatagramChannel udpChannel = mock(DatagramChannel.class); + when(udpKey.channel()).thenReturn(udpChannel); + selectionKeySet.add(udpKey); + + VpnWriter vpnWriter = spy(new VpnWriter(fileOutputStream, sessionManager, workerPool)); + doNothing().when(vpnWriter).processUdpSelectionKey(any()); + + Selector selector = mock(Selector.class); + when(sessionManager.getSelector()).thenReturn(selector); + when(selector.selectedKeys()).thenReturn(selectionKeySet); + Thread t = new Thread(vpnWriter); + t.start(); + when(vpnWriter.isRunning()).thenReturn(true).thenReturn(false); + when(vpnWriter.notRunning()).thenReturn(false).thenReturn(true); + vpnWriter.shutdown(); + t.join(); + } }