|
- import socket
- import select
- import ssl
- import struct
- import threading
- from threading import Thread
- import logging as log
-
-
class SocketServer:
    """Threaded TCP server that frames messages and dispatches them to hooks.

    One thread accepts connections; every accepted client is served by its
    own thread running ``pre_handle_client``. Subclasses customise behaviour
    by overriding ``on_connect``, ``on_receive`` and ``on_disconnect``.
    """

    # Message framing modes handed to every ClientConnection.
    NO_BUFFERING = 0
    LENGTH_BUFFERING = 1
    DELIMITER_BUFFERING = 2

    def __init__(self, ssl_cert=None, ssl_key=None, buffering=LENGTH_BUFFERING,
                 delimiter='\n'):
        """Store the configuration; no socket is opened until ``connect``.

        Args:
            ssl_cert: Path to a certificate file; TLS is enabled only when
                both ``ssl_cert`` and ``ssl_key`` are given.
            ssl_key: Path to the private key file matching ``ssl_cert``.
            buffering: One of ``NO_BUFFERING``, ``LENGTH_BUFFERING`` or
                ``DELIMITER_BUFFERING``.
            delimiter: Message terminator used with ``DELIMITER_BUFFERING``.
        """
        self.terminated = False
        self.connected = False
        self.buffering = buffering
        self.ssl_cert, self.ssl_key = ssl_cert, ssl_key
        self.delimiter = delimiter
        # Defined up front so clean_up() is safe even if connect() never ran.
        self.socket = None
        log.debug("Initialized socket server")

    def connect(self, host, port):
        """Create the listening socket, bind it to ``(host, port)`` and listen.

        Raises:
            OSError: If wrapping, binding or listening fails. The error is
                logged and the socket closed before re-raising, so the server
                never ends up listening on an unbound socket.
        """
        sock = socket.socket(socket.AF_INET, socket.SOCK_STREAM)
        try:
            sock.setsockopt(socket.SOL_SOCKET, socket.SO_REUSEADDR, 1)
            if self.ssl_cert and self.ssl_key:
                # ssl.wrap_socket() was removed in Python 3.12; an SSLContext
                # is the supported way to wrap a server socket.
                context = ssl.create_default_context(ssl.Purpose.CLIENT_AUTH)
                context.load_cert_chain(certfile=self.ssl_cert,
                                        keyfile=self.ssl_key)
                sock = context.wrap_socket(sock, server_side=True)
            sock.bind((host, port))
            sock.listen(5)
        except OSError as e:
            log.critical("Error: %s", str(e))
            sock.close()
            raise
        self.socket = sock

    def start(self):
        """Start accepting connections on a background thread."""
        Thread(target=self.wait_for_connections).start()

    def wait_for_connections(self):
        """Accept connections until the server is terminated."""
        while not self.terminated:
            self.accept_connection()

    def accept_connection(self):
        """Accept one connection and serve it on a new thread."""
        if self.terminated:
            return
        try:
            conn, addr = self.socket.accept()
        except OSError as e:
            if self.terminated:
                # clean_up() closed the listening socket underneath accept();
                # that is the normal shutdown path, not an error.
                log.debug("Listening socket closed, accept loop stopping")
            else:
                log.error("OSError while accepting connection: %s", str(e))
            return
        Thread(target=self.pre_handle_client, args=(conn, addr)).start()

    def clean_up(self):
        """Flag the server as terminated and close the listening socket."""
        self.terminated = True
        sock = self.socket
        if sock is None:
            return
        try:
            sock.shutdown(socket.SHUT_RDWR)
        except OSError as e:
            # shutdown() on a listening socket fails on some platforms;
            # the socket still has to be closed.
            log.debug("Ignoring shutdown error on listening socket: %s", e)
        finally:
            sock.close()

    def pre_handle_client(self, conn, addr):
        """Serve a single client until it disconnects or the server stops.

        Each received message is passed to ``on_receive``; a non-``None``
        return value is sent back to the client. The connection is always
        closed on exit, even when a hook raises.
        """
        log.info("Established connection to %s", addr)
        client = ClientConnection(conn, addr, self.buffering, self.delimiter)
        try:
            connected = self.on_connect(client)

            while not self.terminated and connected:
                try:
                    # Only wait for readability: a connected socket is almost
                    # always writable, so listing it in the write set makes
                    # select() return immediately and the loop spin.
                    readable, _, _ = select.select([client.conn], [], [], 5)
                except (select.error, ValueError):
                    log.warning("FATAL ERROR: Connection error %s", addr)
                    self.on_disconnect(client)
                    break
                if not readable:
                    # Timed out; loop around to re-check self.terminated.
                    continue
                text = client.recv()
                if text is None:
                    log.critical("no text --> disconnecting")
                    self.on_disconnect(client)
                    break
                log.debug("Received: %s %s", text, addr)
                try:
                    reply = self.on_receive(client, text)
                except TypeError as e:
                    # Keep serving the client, but do not hide the failure.
                    log.warning("%s on_receive raised TypeError: %s", addr, e)
                    continue
                if reply is not None:
                    client.send(reply)
        finally:
            log.critical("Terminating connection with %s", addr)
            client.quit()

    def on_receive(self, client, query):
        """Hook: handle one message. The default echoes ``query`` back."""
        return query

    def on_connect(self, client):
        """Hook: called once per client; return False to refuse it."""
        return True

    def on_disconnect(self, client):
        """Hook: called when a client disconnects or its socket fails."""
        pass
-
-
class ClientConnection:
    """One accepted client socket plus the framing used to talk to it."""

    def __init__(self, conn, addr, buffering, delimiter):
        """Wrap an accepted socket.

        Args:
            conn: The connected socket object.
            addr: Peer address, used only in log messages.
            buffering: One of the ``SocketServer.*_BUFFERING`` constants.
            delimiter: Message terminator (``str``) for delimiter framing.
        """
        self.conn = conn
        self.addr = addr
        # Serialises sendall() calls so concurrent senders cannot interleave.
        self.lock = threading.Lock()
        self.buffering, self.delimiter = buffering, delimiter
        # Bytes received after the last delimiter, kept for the next recv().
        self.left = b''

    def ask(self, key, question):
        """Send ``question`` and wait up to two seconds for the reply.

        Args:
            key: Accepted for interface compatibility but not transmitted,
                because ``send`` takes a single payload.
                NOTE(review): confirm whether ``key`` should be part of the
                message.
            question: Text to send.

        Returns:
            The decoded reply, or ``None`` on timeout, error or EOF.
        """
        previous_timeout = self.conn.gettimeout()
        self.send(question)
        self.conn.settimeout(2)
        try:
            return self.recv()
        finally:
            # Do not leave the short timeout in force for later reads.
            try:
                self.conn.settimeout(previous_timeout)
            except OSError as e:
                log.debug("%s Could not restore timeout: %s", self.addr, e)

    def recv(self):
        """Receive one message according to the configured framing.

        Returns:
            The message decoded as UTF-8, or ``None`` when the peer closed
            the connection or any error occurred while receiving.
        """
        try:
            if self.buffering == SocketServer.DELIMITER_BUFFERING:
                result = recvuntil(self.conn, self.delimiter, self.left)
                if result is None:
                    log.debug("%s No delimiter: EOF", self.addr)
                    return None
                payload, self.left = result
            elif self.buffering == SocketServer.LENGTH_BUFFERING:
                raw_msglen = recvall(self.conn, 4)
                if not raw_msglen:
                    log.debug("%s No msglen: EOF", self.addr)
                    return None
                msglen = struct.unpack('>I', raw_msglen)[0]
                payload = recvall(self.conn, msglen)
                if payload is None:
                    log.debug("%s Truncated message: EOF", self.addr)
                    return None
            else:
                payload = self.conn.recv(1024)
                if not payload:
                    # An empty read means the peer closed the connection;
                    # returning '' would make the caller poll a dead socket.
                    log.debug("%s No data: EOF", self.addr)
                    return None
            return payload.decode("utf-8")
        except Exception as e:
            log.critical("%s Error while receiving: %s", self.addr, str(e))
            return None

    def send(self, query):
        """Frame ``query`` (a ``str``) and send it; failures are only logged."""
        payload = query.encode()
        if self.buffering == SocketServer.DELIMITER_BUFFERING:
            msg = payload + self.delimiter.encode()
        elif self.buffering == SocketServer.LENGTH_BUFFERING:
            # The prefix must count encoded bytes, not characters, otherwise
            # non-ASCII text desynchronises the stream.
            msg = struct.pack('>I', len(payload)) + payload
        else:
            msg = payload
        with self.lock:
            try:
                self.conn.sendall(msg)
            except Exception as e:
                log.debug("%s Failed to send %s: %s", self.addr, msg, e)

    def on_recv(self):
        """Placeholder hook; does nothing."""
        pass

    def quit(self):
        """Close the underlying socket."""
        self.conn.close()
-
-
def recvall(sock, n):
    """Read exactly ``n`` bytes from ``sock``.

    Returns the collected bytes, or ``None`` if the peer closes the
    connection before ``n`` bytes have arrived.
    """
    remaining = n
    chunks = []
    while remaining > 0:
        chunk = sock.recv(remaining)
        if not chunk:
            # Empty read: the connection ended before the full message came.
            return None
        chunks.append(chunk)
        remaining -= len(chunk)
    return b''.join(chunks)
-
-
def recvuntil(sock, delimiter, left_over):
    """Read from ``sock`` until ``delimiter`` is seen.

    Args:
        sock: Object with a ``recv(bufsize)`` method returning bytes.
        delimiter: Message terminator as ``str``; it is encoded before
            searching.
        left_over: Bytes left unconsumed by the previous call.

    Returns:
        ``(message, remainder)``: the bytes before the delimiter and the
        bytes after it, which should be passed back as ``left_over`` next
        time. Returns ``None`` if the peer closes the connection before a
        delimiter arrives.

    Raises:
        ValueError: If ``delimiter`` is empty.
    """
    marker = delimiter.encode()
    if not marker:
        # An empty marker matches at offset 0 without consuming anything.
        raise ValueError("delimiter must not be empty")
    data = left_over
    while True:
        # Search first: the leftover bytes may already hold a full message,
        # in which case blocking on recv() would stall the caller.
        idx = data.find(marker)
        if idx >= 0:
            # Split using the original buffer and the encoded marker length
            # so the bytes after the delimiter are preserved.
            return data[:idx], data[idx + len(marker):]
        packet = sock.recv(512)
        if not packet:
            return None
        data += packet
-
if __name__ == '__main__':
    # Script entry point: run a debug-logging server on localhost:4242 that
    # frames messages with CRLF.
    log.basicConfig(level=log.DEBUG)
    server = SocketServer(
        delimiter='\r\n',
        buffering=SocketServer.DELIMITER_BUFFERING,
    )
    server.connect('localhost', 4242)
    server.start()
|