Sample Code Illustrating a Secure RMI Connection

The example HelloClient.java illustrates how to create a secure Java Remote Method Invocation (RMI) connection. The sample code is basically a "Hello World" example modified to install and use a custom RMI socket factory. It uses RMI over an SSL transport layer using JSSE. The server runs HelloImpl.java, which sets up an internal RMI registry (rather than using the rmiregistry command). The client runs HelloClient and communicates over a secured connection.

Usage

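Setting up this sample can be a little tricky. The server's RMISSLServerSocketFactory loads its key material from a JKS keystore file named testkeys, protected by the password passphrase, in the working directory, so that file must exist before you start the server. Here are the necessary steps: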

% javac *.java

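% rmic HelloImpl (optional: RMI generates stubs dynamically since J2SE 5.0, and the rmic tool is not included in JDK 15 and later)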

% java -Djava.security.policy=policy HelloImpl (run in another window)

% java HelloClient (run in another window)

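For the server, a security manager will be installed, and the supplied policy file grants permission to accept connections from, and resolve, any host. Such broad network permissions should not be granted in a production environment. (Note that the Security Manager is deprecated for removal: on JDK 18 through 23 the server must also be started with -Djava.security.manager=allow, and JDK 24 and later cannot install one at all.) Instead, grant appropriately restrictive network privileges, such as: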

permission java.net.SocketPermission "hostname:1024-", "accept,resolve";

In addition, this example can easily be updated to run with the standard SSL/TLS-based RMI socket factories. To do this, modify HelloImpl.java (and HelloClient.java, which also creates an RMISSLClientSocketFactory) to use:

javax.rmi.ssl.SslRMIClientSocketFactory
javax.rmi.ssl.SslRMIServerSocketFactory

instead of:

RMISSLClientSocketFactory
RMISSLServerSocketFactory

These classes use SSLSocketFactory.getDefault() and SSLServerSocketFactory.getDefault(), so you will need to configure the JVM so that they can locate your key and trust material.

Note:

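If you use the standard SSL/TLS-based RMI socket factories, then you can specify the server's key store with system properties, as shown below. The client always uses the default SSLSocketFactory, so it locates its trust store the same way, through -Djavax.net.ssl.trustStore and -Djavax.net.ssl.trustStorePassword.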

-Djavax.net.ssl.keyStore=testkeys
-Djavax.net.ssl.keyStorePassword=passphrase

HelloClient.java

import java.net.InetAddress;
import java.rmi.RemoteException;
import java.rmi.registry.LocateRegistry;
import java.rmi.registry.Registry;

public class HelloClient {

    /** Port of the SSL-protected registry that HelloImpl creates. */
    private static final int PORT = 2019;

    /**
     * Looks up "HelloServer" in the SSL-based registry on the local host,
     * calls {@code sayHello()} on it and prints the reply followed by an
     * extra newline. On failure, the exception message is printed to standard
     * output and the stack trace to standard error.
     *
     * @param args ignored
     */
    public static void main(String[] args) {
        try {
            String localHost = InetAddress.getLocalHost().getHostName();

            // The registry itself is reached over SSL, so it must be located
            // with the same client socket factory the server created it with.
            Registry sslRegistry = LocateRegistry.getRegistry(
                    localHost, PORT, new RMISSLClientSocketFactory());

            // The stub is typed by the remote "Hello" interface.
            Hello remoteHello = (Hello) sslRegistry.lookup("HelloServer");
            String reply = remoteHello.sayHello();
            System.out.println(reply + "\n");
        } catch (Exception e) {
            System.out.println("HelloClient exception: " + e.getMessage());
            e.printStackTrace();
        }
    }
}

Hello.java

import java.rmi.Remote;
import java.rmi.RemoteException;

/**
 * Remote greeting service: implemented by HelloImpl and invoked by HelloClient.
 */
public interface Hello extends Remote {

    /**
     * Asks the server for its greeting.
     *
     * @return the greeting text produced by the server
     * @throws RemoteException if the remote invocation fails
     */
    String sayHello() throws RemoteException;
}

HelloImpl.java

import java.io.*;
import java.net.InetAddress;
import java.rmi.RemoteException;
import java.rmi.RMISecurityManager;
import java.rmi.registry.LocateRegistry;
import java.rmi.registry.Registry;
import java.rmi.server.UnicastRemoteObject;

public class HelloImpl extends UnicastRemoteObject implements Hello {

    private static final int PORT = 2019;

    public HelloImpl() throws Exception {
        super(PORT,
              new RMISSLClientSocketFactory(),
              new RMISSLServerSocketFactory());
    }

    public String sayHello() {
        return "Hello World!";
    }

    public static void main(String args[]) {

        // Create and install a security manager
        if (System.getSecurityManager() == null) {
            System.setSecurityManager(new RMISecurityManager());
        }

        try {
            // Create SSL-based registry
            Registry registry = LocateRegistry.createRegistry(PORT,
                new RMISSLClientSocketFactory(),
                new RMISSLServerSocketFactory());

            HelloImpl obj = new HelloImpl();

            // Bind this object instance to the name "HelloServer"
            registry.bind("HelloServer", obj);

            System.out.println("HelloServer bound in registry");
        } catch (Exception e) {
            System.out.println("HelloImpl err: " + e.getMessage());
            e.printStackTrace();
        }
    }
}

RMISSLClientSocketFactory.java

import java.io.*;
import java.net.*;
import java.rmi.server.*;
import javax.net.ssl.*;

public class RMISSLClientSocketFactory
        implements RMIClientSocketFactory, Serializable {

    public Socket createSocket(String host, int port) throws IOException {
            SSLSocketFactory factory =
                (SSLSocketFactory)SSLSocketFactory.getDefault();
            SSLSocket socket = (SSLSocket)factory.createSocket(host, port);
            return socket;
    }

    public int hashCode() {
        return getClass().hashCode();
    }

    public boolean equals(Object obj) {
        if (obj == this) {
            return true;
        } else if (obj == null || getClass() != obj.getClass()) {
            return false;
        }
        return true;
    }
}

RMISSLServerSocketFactory.java

import java.io.*;
import java.net.*;
import java.rmi.server.*;
import javax.net.ssl.*;
import java.security.KeyStore;
import javax.net.ssl.*;

public class RMISSLServerSocketFactory implements RMIServerSocketFactory {

    /*
     * One SSLServerSocketFactory is created per instance, so that sockets it
     * produces can reuse sessions created earlier by the same SSLContext.
     */
    private final SSLServerSocketFactory ssf;

    /**
     * Builds a TLS server socket factory. Its key material comes from the JKS
     * keystore file "testkeys" (resolved against the working directory),
     * opened with the password "passphrase".
     *
     * @throws Exception if the keystore cannot be read or parsed, or if the
     *     key manager or SSL context cannot be initialized; the exception
     *     propagates unchanged to the caller
     */
    public RMISSLServerSocketFactory() throws Exception {
        char[] passphrase = "passphrase".toCharArray();

        KeyStore ks = KeyStore.getInstance("JKS");
        // try-with-resources: the keystore stream is always closed,
        // even if loading fails.
        try (InputStream in = new FileInputStream("testkeys")) {
            ks.load(in, passphrase);
        }

        // set up key manager to do server authentication
        KeyManagerFactory kmf = KeyManagerFactory.getInstance("SunX509");
        kmf.init(ks, passphrase);

        SSLContext ctx = SSLContext.getInstance("TLS");
        // null trust managers and null SecureRandom select the JSSE defaults.
        ctx.init(kmf.getKeyManagers(), null, null);

        ssf = ctx.getServerSocketFactory();
    }

    /**
     * Creates an SSL server socket bound to {@code port}.
     *
     * @param port the port to listen on (0 for an ephemeral port)
     * @return the new server socket
     * @throws IOException if the socket cannot be created or bound
     */
    @Override
    public ServerSocket createServerSocket(int port) throws IOException {
        return ssf.createServerSocket(port);
    }

    /** Hashes on the class alone, consistent with {@link #equals(Object)}. */
    @Override
    public int hashCode() {
        return getClass().hashCode();
    }

    /**
     * Instances are equal whenever they have the same class. Every instance
     * is built from the same keystore, so they are interchangeable.
     */
    @Override
    public boolean equals(Object obj) {
        if (obj == this) {
            return true;
        } else if (obj == null || getClass() != obj.getClass()) {
            return false;
        }
        return true;
    }
}

policy

// In this example, for simplicity, we use a policy file that lets any
// code accept connections from, and resolve, any host. Do not use this
// policy file in a production environment.

grant {
	permission java.net.SocketPermission "*", "accept,resolve";
};