|
|
|
@@ -0,0 +1,187 @@ |
|
|
|
import socket |
|
|
|
import select |
|
|
|
import ssl |
|
|
|
import struct |
|
|
|
import threading |
|
|
|
from threading import Thread |
|
|
|
import logging as log |
|
|
|
|
|
|
|
|
|
|
|
class SocketServer: |
|
|
|
|
|
|
|
NO_BUFFERING = 0 |
|
|
|
LENGTH_BUFFERING = 1 |
|
|
|
DELIMITER_BUFFERING = 2 |
|
|
|
|
|
|
|
def __init__(self, ssl_cert=None, ssl_key=None, buffering=LENGTH_BUFFERING, |
|
|
|
delimiter='\n'): |
|
|
|
self.terminated = False |
|
|
|
self.connected = False |
|
|
|
self.buffering = buffering |
|
|
|
self.ssl_cert, self.ssl_key = ssl_cert, ssl_key |
|
|
|
self.delimiter = delimiter |
|
|
|
log.debug("Initialized socket server") |
|
|
|
|
|
|
|
def connect(self, host, port): |
|
|
|
self.socket = socket.socket(socket.AF_INET, socket.SOCK_STREAM) |
|
|
|
self.socket.setsockopt(socket.SOL_SOCKET, socket.SO_REUSEADDR, 1) |
|
|
|
|
|
|
|
if self.ssl_cert and self.ssl_key: |
|
|
|
self.socket = ssl.wrap_socket(self.socket, |
|
|
|
certfile=self.ssl_cert, |
|
|
|
keyfile=self.ssl_key, |
|
|
|
server_side=True) |
|
|
|
try: |
|
|
|
self.socket.bind((host, port)) |
|
|
|
except socket.error as e: |
|
|
|
log.critical("Error: %s", str(e)) |
|
|
|
self.socket.listen(5) |
|
|
|
|
|
|
|
def start(self): |
|
|
|
Thread(target=self.wait_for_connections).start() |
|
|
|
|
|
|
|
def wait_for_connections(self): |
|
|
|
while not self.terminated: |
|
|
|
self.accept_connection() |
|
|
|
|
|
|
|
def accept_connection(self): |
|
|
|
if self.terminated: |
|
|
|
return |
|
|
|
try: |
|
|
|
conn, addr = self.socket.accept() |
|
|
|
except OSError as e: |
|
|
|
log.error("OSError while accepting connection: %s", str(e)) |
|
|
|
return |
|
|
|
Thread(target=self.pre_handle_client, args=(conn, addr)).start() |
|
|
|
|
|
|
|
def clean_up(self): |
|
|
|
self.terminated = True |
|
|
|
self.socket.shutdown(socket.SHUT_RDWR) |
|
|
|
self.socket.close() |
|
|
|
|
|
|
|
def pre_handle_client(self, conn, addr): |
|
|
|
log.info("Established connection to %s", addr) |
|
|
|
client = ClientConnection(conn, addr, self.buffering, self.delimiter) |
|
|
|
connected = self.on_connect(client) |
|
|
|
|
|
|
|
while not self.terminated and connected: |
|
|
|
try: |
|
|
|
read, write, error = select.select([client.conn, ], |
|
|
|
[client.conn, ], |
|
|
|
[], 5) |
|
|
|
except (select.error, ValueError): |
|
|
|
log.warn("FATAL ERROR: Connection error %s", addr) |
|
|
|
self.on_disconnect(client) |
|
|
|
break |
|
|
|
if len(read) > 0: |
|
|
|
text = client.recv() |
|
|
|
if text is None: |
|
|
|
log.critical("no text --> disconnecting") |
|
|
|
self.on_disconnect(client) |
|
|
|
break |
|
|
|
log.debug("Received: %s %s", text, addr) |
|
|
|
try: |
|
|
|
reply = self.on_receive(client, text) |
|
|
|
except TypeError: |
|
|
|
pass |
|
|
|
else: |
|
|
|
if reply is not None: |
|
|
|
client.send(reply) |
|
|
|
|
|
|
|
log.critical("Terminating connection with %s", addr) |
|
|
|
client.quit() |
|
|
|
|
|
|
|
def on_receive(self, client, query): |
|
|
|
return query |
|
|
|
|
|
|
|
def on_connect(self, client): |
|
|
|
return True |
|
|
|
|
|
|
|
def on_disconnect(self, client): |
|
|
|
pass |
|
|
|
|
|
|
|
|
|
|
|
class ClientConnection: |
|
|
|
|
|
|
|
def __init__(self, conn, addr, buffering, delimiter): |
|
|
|
self.conn = conn |
|
|
|
self.addr = addr |
|
|
|
self.lock = threading.Lock() |
|
|
|
self.buffering, self.delimiter = buffering, delimiter |
|
|
|
self.left = b'' |
|
|
|
|
|
|
|
def ask(self, key, question): |
|
|
|
self.send(key, question) |
|
|
|
self.conn.settimeout(2) |
|
|
|
return self.recv() |
|
|
|
|
|
|
|
def recv(self): |
|
|
|
try: |
|
|
|
if self.buffering is SocketServer.DELIMITER_BUFFERING: |
|
|
|
recv, self.left = recvuntil(self.conn, self.delimiter, |
|
|
|
self.left) |
|
|
|
elif self.buffering is SocketServer.LENGTH_BUFFERING: |
|
|
|
raw_msglen = recvall(self.conn, 4) |
|
|
|
if not raw_msglen: |
|
|
|
log.debug("%s No msglen: EOF", self.addr) |
|
|
|
return None |
|
|
|
msglen = struct.unpack('>I', raw_msglen)[0] |
|
|
|
recv = recvall(self.conn, msglen) |
|
|
|
else: |
|
|
|
recv = self.conn.recv(1024) |
|
|
|
return recv.decode("utf-8") |
|
|
|
except Exception as e: |
|
|
|
log.critical("%s Error while receiving: %s", self.addr, str(e)) |
|
|
|
return None |
|
|
|
|
|
|
|
def send(self, query): |
|
|
|
if self.buffering is SocketServer.DELIMITER_BUFFERING: |
|
|
|
msg = (query + self.delimiter).encode() |
|
|
|
elif self.buffering is SocketServer.LENGTH_BUFFERING: |
|
|
|
msg = struct.pack('>I', len(query)) + query.encode() |
|
|
|
else: |
|
|
|
msg = query.encode() |
|
|
|
with self.lock: |
|
|
|
try: |
|
|
|
self.conn.sendall(msg) |
|
|
|
except Exception as e: |
|
|
|
log.debug("%s Failed to send %s: %s", self.addr, msg, e) |
|
|
|
|
|
|
|
def on_recv(self): |
|
|
|
pass |
|
|
|
|
|
|
|
def quit(self): |
|
|
|
self.conn.close() |
|
|
|
|
|
|
|
|
|
|
|
def recvall(sock, n): |
|
|
|
data = b'' |
|
|
|
while len(data) < n: |
|
|
|
packet = sock.recv(n - len(data)) |
|
|
|
if not packet: |
|
|
|
return None |
|
|
|
data += packet |
|
|
|
return data |
|
|
|
|
|
|
|
|
|
|
|
def recvuntil(sock, delimiter, left_over): |
|
|
|
data = left_over |
|
|
|
left = b'' |
|
|
|
while True: |
|
|
|
packet = sock.recv(512) |
|
|
|
if not packet: |
|
|
|
return None |
|
|
|
data += packet |
|
|
|
idx = data.find(delimiter.encode()) |
|
|
|
if idx >= 0: |
|
|
|
data = data[:idx] |
|
|
|
left = data[idx+len(delimiter):] |
|
|
|
break |
|
|
|
return data, left |
|
|
|
|
|
|
|
if __name__ == '__main__': |
|
|
|
log.basicConfig(level=log.DEBUG) |
|
|
|
server = SocketServer(buffering=SocketServer.DELIMITER_BUFFERING, |
|
|
|
delimiter='\r\n') |
|
|
|
server.connect('localhost', 4242) |
|
|
|
server.start() |