diff --git a/pywebsrv.py b/pywebsrv.py index 1b66091..74496bd 100644 --- a/pywebsrv.py +++ b/pywebsrv.py @@ -114,7 +114,8 @@ class FileHandler: "allow-localhost", "disable-autocertgen", "key-file", - "cert-file" + "cert-file", + "block-ua" ] if option not in valid_options: return None @@ -131,6 +132,19 @@ class FileHandler: if option == "host": seperated_values = value.split(",", -1) return [value.lower() for value in seperated_values] + if option == "block-ua": + seperated_values = value.split(",", -1) + host_to_match = [] + literal_blocks = [] + for val in seperated_values: + if val.startswith("match(") and val.endswith(")"): + idx = val.index("(") + idx2 = val.index(")") + ua_to_match = val[idx+1:idx2] + host_to_match.append(ua_to_match) + else: + literal_blocks.append(val) + return host_to_match, literal_blocks if option == "port" or option == "port-https": return int(value) if ( @@ -172,6 +186,16 @@ class RequestParser: path += "index.html" return method, path, version + def ua_blocker(self, ua): + """Parses and matches UA to block""" + match, literal = self.file_handler.read_config("block-ua") + if ua in literal: + return False + for _ua in match: + if _ua.lower() in ua.lower(): + return False + return True + def is_method_allowed(self, method): """ Checks if the HTTP method is allowed. @@ -303,7 +327,6 @@ class WebServer: while self.running: try: conn, addr = self.http_socket.accept() - print(f"HTTP connection received from {addr}") self.handle_connection(conn, addr) except Exception as e: print(f"HTTP error: {e}") @@ -316,7 +339,6 @@ class WebServer: while self.running: try: conn, addr = self.https_socket.accept() - print(f"HTTPS connection received from {addr}") self.handle_connection(conn, addr) except Exception as e: print( @@ -329,7 +351,12 @@ class WebServer: try: data = conn.recv(512) request = data.decode(errors="ignore") - response = self.handle_request(request, addr) + if not data: + response = self.build_response(400, "Bad Request") # user did fucky-wucky + elif len(data) > 8192: + response = self.build_response(413, "Request too long") + else: + response = self.handle_request(request, addr) if isinstance(response, str): response = response.encode() @@ -341,12 +368,7 @@ class WebServer: conn.close() def handle_request(self, data, addr): - print(f"len data: {len(data)}") - if not data: - return self.build_response(400, "Bad Request") # user did fucky-wucky - if len(data) > 8192: - return self.build_response(413, "Request too long") - + print(f"data: {data}") request_line = data.splitlines()[0] # Extract host from headers, never works though @@ -364,17 +386,32 @@ class WebServer: 400, self.no_host_req_response.encode() ) + for line in data.splitlines(): + if "User-Agent" in line: + ua = line.split(":", 1)[1].strip() + allowed = self.parser.ua_blocker(ua) + if not allowed: + return self.build_response( + 403, "This UA has been blocked by the owner of this site." + ) + break + else: + return self.build_response( + 400, "You cannot connect without a User-Agent." + ) + method, path, version = self.parser.parse_request_line(request_line) + if not all([method, path, version]): + return self.build_response(400, "Bad Request") + # Figure out a better way to reload config if path == "/?pywebsrv_reload_conf=1": print("Got reload command! Reloading configuration...") - self.file_handler.base_dir = self.file_handler.read_config("directory") + self.file_handler = FileHandler() + self.parser = RequestParser() return self.build_response(302, "") - if not all([method, path, version]): - return self.build_response(400, "Bad Request") - if not self.parser.is_method_allowed( method ): @@ -397,6 +434,9 @@ class WebServer: # A really crude implementation of binary files. Later in 2.0 I'll actually # make this useful. mimetype = mimetype[0] + if mimetype is None: + # We have to assume it's binary. + return self.build_binary_response(200, file_content, "application/octet-stream") if "text/" not in mimetype: return self.build_binary_response(200, file_content, mimetype) @@ -410,7 +450,7 @@ class WebServer: 403: "Forbidden", 404: "Not Found", 405: "Method Not Allowed", - 500: "Internal Server Error", + 500: "Internal Server Error" } status_message = messages.get(status_code) headers = ( @@ -421,6 +461,7 @@ class WebServer: f"Connection: close\r\n\r\n" # Connection close is done because it is way easier to implement. # It's not like this program will see production use anyway. + # Tbh when i'll implement HTTP2 ) return headers.encode() + binary_data @@ -440,7 +481,7 @@ class WebServer: 405: "Method Not Allowed", 413: "Payload Too Large", 500: "Internal Server Error", - 635: "Go Away", + 635: "Go Away" } status_message = messages.get(status_code)