import socket
import struct
import ssl
from threading import Thread
import select


class SocketClient:
    """Threaded TCP client with optional TLS and three message framings.

    Framing modes (the ``buffering`` constructor argument):

    * ``NO_BUFFERING``        - raw bytes, whatever a single ``recv`` returns.
    * ``LENGTH_BUFFERING``    - each message is prefixed with a 4-byte
      big-endian unsigned length.
    * ``DELIMITER_BUFFERING`` - messages are terminated by ``delimiter``.

    Override (or assign) ``on_receive`` and ``on_quit`` to react to incoming
    messages and to the end of the connection.
    """

    NO_BUFFERING = 0
    LENGTH_BUFFERING = 1
    DELIMITER_BUFFERING = 2

    def __init__(self, ssl_cert=None, buffering=LENGTH_BUFFERING,
                 delimiter='\n'):
        """Store the configuration; no network activity happens here.

        Args:
            ssl_cert: Path to a CA certificate file. When truthy, the
                connection is wrapped in TLS and the server certificate is
                verified against it.
            buffering: One of the ``*_BUFFERING`` class constants.
            delimiter: Message terminator (str) for ``DELIMITER_BUFFERING``.
        """
        self.terminated = False
        self.connected = False
        self.ssl_cert = ssl_cert
        self.buffering = buffering
        self.delimiter = delimiter
        # Bytes received after the last delimiter, kept for the next message.
        self.left = b''
        # Set by connect(); None until then so other methods can check it.
        self.socket = None
        print("Initialized socket client")

    def connect(self, host, port):
        """Open the connection to ``(host, port)``.

        Returns:
            True on success, False if the connection (or TLS setup) failed.
        """
        sock = socket.socket()
        self.socket = sock
        try:
            # if using a ssl certificate, wrap the socket using ssl
            if self.ssl_cert:
                # ssl.wrap_socket() was removed in Python 3.12. This context
                # reproduces its behaviour for the arguments that were used:
                # the certificate chain is required and verified against
                # ssl_cert, but the host name is not checked.
                context = ssl.SSLContext(ssl.PROTOCOL_TLS_CLIENT)
                context.check_hostname = False
                context.verify_mode = ssl.CERT_REQUIRED
                context.load_verify_locations(cafile=self.ssl_cert)
                sock = context.wrap_socket(sock)
                self.socket = sock
            sock.connect((host, port))
        except OSError:  # socket.error and ssl.SSLError are OSError
            print("ERROR: Error while connecting!")
            sock.close()  # do not leak the descriptor on failure
            self.connected = False
            return False
        self.connected = True
        return True

    def start(self):
        """Start the receive loop (``handle_server``) in a new thread."""
        print("Client has started")
        Thread(target=self.handle_server).start()

    def disconnect(self):
        """Send ``"quit"`` to the server, stop the loop and close the socket."""
        print("disconnecting")
        self.connected = False
        self.send("quit")
        self.terminated = True
        if self.socket is not None:
            self.socket.close()

    def send(self, data):
        """Frame ``data`` according to the buffering mode and send it.

        Args:
            data: Text to send (encoded with the default UTF-8 codec);
                ``bytes``/``bytearray`` are sent as they are.

        Errors are reported on stdout and not raised.
        """
        try:
            payload = data.encode() if isinstance(data, str) else bytes(data)
            if self.buffering == SocketClient.DELIMITER_BUFFERING:
                msg = payload + self.delimiter.encode()
            elif self.buffering == SocketClient.LENGTH_BUFFERING:
                # The prefix must count encoded bytes, not characters,
                # otherwise non-ASCII text desynchronises the stream.
                msg = struct.pack('>I', len(payload)) + payload
            else:
                msg = payload
            self.socket.sendall(msg)
        except Exception as e:
            print("ERROR: Error while sending", e)

    def handle_server(self):
        """Receive messages until terminated or the connection ends.

        Every non-empty message is passed to ``on_receive``; a truthy return
        value is sent back to the server unchanged. ``on_quit`` is called
        when the peer closes the connection or a socket error occurs.
        """
        if self.socket is None:
            print("FATAL ERROR: Connection error")
            return

        while not self.terminated:
            try:
                # Only wait on the socket when no complete message is
                # already buffered; the timeout lets the loop notice
                # ``terminated`` being set by disconnect().
                if not self._has_buffered_message():
                    readable, _, _ = select.select([self.socket], [], [], 5)
                    if not readable:
                        continue
                message = self._receive_message()
            except (OSError, ValueError):
                # ValueError: select() on a socket that is already closed.
                if self.terminated:
                    break  # closed on purpose by disconnect()
                print("FATAL ERROR: Connection error")
                self._close_quietly()
                self.connected = False
                self.on_quit()
                return

            if message is None:
                print("Connection closed")
                self.connected = False
                self.on_quit()
                return None

            if message:
                reply = self.on_receive(message)
                if reply:
                    # sendall: a plain send() may transmit only a prefix.
                    self.socket.sendall(reply)

        print("finished handling server")

    def _has_buffered_message(self):
        """Return True if ``self.left`` already holds a full delimited message."""
        return (self.buffering == SocketClient.DELIMITER_BUFFERING
                and self.delimiter.encode() in self.left)

    def _receive_message(self):
        """Read one message; return its bytes, or None if the peer closed."""
        if self.buffering == SocketClient.DELIMITER_BUFFERING:
            result = recvuntil(self.socket, self.delimiter, self.left)
            if result is None:
                return None
            message, self.left = result
            return message

        if self.buffering == SocketClient.LENGTH_BUFFERING:
            raw_msglen = recvall(self.socket, 4)
            if not raw_msglen:
                return None
            msglen = struct.unpack('>I', raw_msglen)[0]
            return recvall(self.socket, msglen)

        # NO_BUFFERING: an empty read means the peer closed the connection.
        return self.socket.recv(2048) or None

    def _close_quietly(self):
        """Shut down and close the socket, ignoring shutdown errors."""
        try:
            self.socket.shutdown(socket.SHUT_RDWR)
        except OSError:
            # Best effort: the socket may already be disconnected.
            pass
        self.socket.close()

    def on_receive(self, query):
        """Hook called with each received message (bytes); no-op by default.

        A truthy return value is sent back to the server by
        ``handle_server``.
        """
        pass

    def on_quit(self):
        """Hook called when the connection ends; no-op by default."""
        pass


def recvall(sock, n):
    """Read exactly ``n`` bytes from ``sock``.

    Returns:
        The ``n`` bytes, or None if the peer closed the connection first.
    """
    data = b''
    while len(data) < n:
        packet = sock.recv(n - len(data))
        if not packet:
            return None
        data += packet
    return data


def recvuntil(sock, delimiter, left_over):
    """Read from ``sock`` until ``delimiter`` is seen.

    Args:
        sock: Connected socket to read from.
        delimiter: Message terminator, as str (encoded) or bytes.
        left_over: Bytes already received but not yet consumed.

    Returns:
        ``(message, rest)`` where ``message`` is everything before the
        delimiter and ``rest`` is everything after it, or None if the peer
        closed the connection before a delimiter arrived.
    """
    marker = delimiter.encode() if isinstance(delimiter, str) else delimiter
    data = left_over
    search_from = 0
    while True:
        # Check the buffer first so an already complete message is returned
        # without blocking on the socket.
        idx = data.find(marker, search_from)
        if idx >= 0:
            return data[:idx], data[idx + len(marker):]
        # Next time only rescan the tail where a marker could still start.
        search_from = max(0, len(data) - len(marker) + 1)
        packet = sock.recv(512)
        if not packet:
            return None
        data += packet


if __name__ == '__main__':
    client = SocketClient(buffering=SocketClient.DELIMITER_BUFFERING,
                          delimiter='\r\n')
    # Install the handler before the receive thread starts.
    client.on_receive = lambda x: print("-->", x)
    if client.connect('localhost', 4242):
        client.start()
        try:
            while True:
                x = input("> ")
                client.send(x)
        except (EOFError, KeyboardInterrupt):
            client.disconnect()