1 import os
2 import sys
3 import ssl
4 import pprint
5 import threading
6 import urllib.parse
7 # Rename HTTPServer to _HTTPServer so as to avoid confusion with HTTPSServer.
8 from http.server import (HTTPServer as _HTTPServer,
9     SimpleHTTPRequestHandler, BaseHTTPRequestHandler)
10 
11 from test import support
12 from test.support import socket_helper
13 
# Directory containing this module; keycert.pem is looked up relative to it.
here = os.path.dirname(__file__)

# Default bind address for the test servers, taken from
# test.support.socket_helper (presumably a loopback address — verify there).
HOST = socket_helper.HOST
# PEM file passed to load_cert_chain() by make_https_server() and the
# __main__ script; it is assumed to hold both the private key and certificate.
CERTFILE = os.path.join(here, 'keycert.pem')
18 
# HTTPSServer is based on HTTPServer, which in turn is based on socketserver.
20 
class HTTPSServer(_HTTPServer):
    """HTTPServer whose accepted connections are wrapped with an SSLContext.

    The listening socket itself is a plain TCP socket; every connection
    returned by get_request() is wrapped server-side with ``self.context``.
    """

    def __init__(self, server_address, handler_class, context):
        super().__init__(server_address, handler_class)
        # SSL context used to wrap each accepted connection.
        self.context = context

    def __str__(self):
        return '<{} {}:{}>'.format(type(self).__name__,
                                   self.server_name,
                                   self.server_port)

    def get_request(self):
        """Accept a connection and return ``(ssl_socket, client_address)``.

        Overrides the base implementation to wrap the socket with SSL.
        Errors are echoed to stderr in verbose mode, because the caller
        silences socket errors, and are then re-raised.
        """
        try:
            conn, client_addr = self.socket.accept()
            ssl_conn = self.context.wrap_socket(conn, server_side=True)
        except OSError as exc:
            if support.verbose:
                sys.stderr.write("Got an error:\n%s\n" % exc)
            raise
        return ssl_conn, client_addr
44 
class RootedHTTPRequestHandler(SimpleHTTPRequestHandler):
    """SimpleHTTPRequestHandler that serves files below a fixed ``root``.

    translate_path() is overridden to resolve request paths against a known
    root instead of os.curdir, since the test could be run from anywhere.
    """

    server_version = "TestHTTPS/1.0"
    # Directory that request paths are resolved against.
    root = here
    # Avoid hanging when a request gets interrupted by the client
    timeout = support.LOOPBACK_TIMEOUT

    def translate_path(self, path):
        """Translate a /-separated URL *path* to a filename below ``self.root``.

        Query parameters are discarded, and the path is percent-decoded and
        normalized. Each component is reduced to its final name, so drive
        and directory parts are ignored. ``.`` and ``..`` components are
        skipped, so the result cannot escape ``self.root``.
        """
        # URL paths are always '/'-separated, so normalize them with
        # posixpath. os.path.normpath would turn '/' into '\\' on Windows,
        # which would leave a single component for the split below.
        import posixpath

        # abandon query parameters
        path = urllib.parse.urlparse(path)[2]
        path = posixpath.normpath(urllib.parse.unquote(path))
        result = self.root
        for word in filter(None, path.split('/')):
            word = os.path.split(os.path.splitdrive(word)[1])[1]
            # normpath() keeps leading '..' in relative request paths such
            # as '../../etc/passwd'; never let a component climb out of root.
            if not word or word in (os.curdir, os.pardir):
                continue
            result = os.path.join(result, word)
        return result

    def log_message(self, format, *args):
        """Write a log line to stdout, but only when running verbosely."""
        if support.verbose:
            # server_address is a (host, port) tuple, so print the server's
            # host name in the host slot instead of the whole tuple.
            sys.stdout.write(" server (%s:%d %s):\n   [%s] %s\n" %
                             (self.server.server_name,
                              self.server.server_port,
                              self.request.cipher(),
                              self.log_date_time_string(),
                              format%args))
84 
85 
class StatsRequestHandler(BaseHTTPRequestHandler):
    """Example HTTP request handler which returns SSL statistics on GET
    requests.

    The plain-text response body is a pretty-printed dict holding the
    context's session-cache statistics plus the connection's cipher and
    compression.
    """

    server_version = "StatsHTTPS/1.0"

    def do_GET(self, send_body=True):
        """Serve a GET request; with send_body=False only headers are sent."""
        # self.connection is the accepted socket that rfile/wfile were made
        # from (an SSL socket when served by HTTPSServer). Using it instead
        # of the private self.rfile.raw._sock also works when rfile is
        # unbuffered and therefore has no .raw attribute.
        sock = self.connection
        context = sock.context
        stats = {
            'session_cache': context.session_stats(),
            'cipher': sock.cipher(),
            'compression': sock.compression(),
            }
        body = pprint.pformat(stats)
        body = body.encode('utf-8')
        self.send_response(200)
        self.send_header("Content-type", "text/plain; charset=utf-8")
        self.send_header("Content-Length", str(len(body)))
        self.end_headers()
        if send_body:
            self.wfile.write(body)

    def do_HEAD(self):
        """Serve a HEAD request."""
        self.do_GET(send_body=False)

    def log_request(self, code='-', size='-'):
        """Log an accepted request, but only when running verbosely.

        The signature matches BaseHTTPRequestHandler.log_request, so callers
        relying on the base-class defaults keep working.
        """
        if support.verbose:
            BaseHTTPRequestHandler.log_request(self, code, size)
118 
119 
class HTTPSServerThread(threading.Thread):
    """Daemon thread that runs an HTTPSServer bound to an ephemeral port.

    The server socket is bound in the constructor, so ``port`` is available
    before the thread starts. The server is closed when run() finishes.
    """

    def __init__(self, context, host=HOST, handler_class=None):
        # Optional event set by run(); filled in by start().
        self.flag = None
        handler = handler_class or RootedHTTPRequestHandler
        self.server = HTTPSServer((host, 0), handler, context)
        self.port = self.server.server_port
        super().__init__(daemon=True)

    def __str__(self):
        return "<{} {}>".format(self.__class__.__name__, self.server)

    def start(self, flag=None):
        """Start the thread; *flag*, if truthy, is set first thing in run()."""
        self.flag = flag
        super().start()

    def run(self):
        started = self.flag
        if started:
            started.set()
        # Leaving the with-block calls server_close(), whether
        # serve_forever() returns normally or raises.
        with self.server as httpd:
            httpd.serve_forever(0.05)

    def stop(self):
        """Stop the serving loop by calling server.shutdown()."""
        self.server.shutdown()
148 
149 
def make_https_server(case, *, context=None, certfile=CERTFILE,
                      host=HOST, handler_class=None):
    """Start an HTTPSServerThread and register its teardown on *case*.

    If *context* is None, a default server-side SSL context is created.
    *certfile* is loaded into the context in either case. Returns the
    started thread once it is running. The cleanup added via
    ``case.addCleanup`` stops the server and joins the thread.
    """
    context = (ssl.create_default_context(ssl.Purpose.CLIENT_AUTH)
               if context is None else context)
    # We assume the certfile contains both private key and certificate
    context.load_cert_chain(certfile)

    thread = HTTPSServerThread(context, host, handler_class)
    running = threading.Event()
    thread.start(running)
    running.wait()

    def _shutdown():
        def say(message):
            if support.verbose:
                sys.stdout.write(message)

        say('stopping HTTPS server\n')
        thread.stop()
        say('joining HTTPS thread\n')
        thread.join()

    case.addCleanup(_shutdown)
    return thread
169 
170 
if __name__ == "__main__":
    # Stand-alone HTTPS server for manual testing.
    import argparse
    parser = argparse.ArgumentParser(
        description='Run a test HTTPS server. '
                    'By default, the current directory is served.')
    parser.add_argument('-p', '--port', type=int, default=4433,
                        help='port to listen on (default: %(default)s)')
    parser.add_argument('-q', '--quiet', dest='verbose', default=True,
                        action='store_false', help='be less verbose')
    parser.add_argument('-s', '--stats', dest='use_stats_handler', default=False,
                        action='store_true', help='always return stats page')
    parser.add_argument('--curve-name', dest='curve_name', type=str,
                        action='store',
                        help='curve name for EC-based Diffie-Hellman')
    parser.add_argument('--ciphers', dest='ciphers', type=str,
                        help='allowed cipher list')
    parser.add_argument('--dh', dest='dh_file', type=str, action='store',
                        help='PEM file containing DH parameters')
    args = parser.parse_args()

    support.verbose = args.verbose
    if args.use_stats_handler:
        handler_class = StatsRequestHandler
    else:
        handler_class = RootedHTTPRequestHandler
        # Serve the directory the script was started from.
        handler_class.root = os.getcwd()
    context = ssl.create_default_context(ssl.Purpose.CLIENT_AUTH)
    context.load_cert_chain(CERTFILE)
    if args.curve_name:
        context.set_ecdh_curve(args.curve_name)
    if args.dh_file:
        context.load_dh_params(args.dh_file)
    if args.ciphers:
        context.set_ciphers(args.ciphers)

    # The with-block guarantees server_close() releases the listening socket.
    with HTTPSServer(("", args.port), handler_class, context) as server:
        if args.verbose:
            print("Listening on https://localhost:{0.port}".format(args))
        try:
            server.serve_forever(0.1)
        except KeyboardInterrupt:
            # Ctrl-C is the normal way to stop this server; exit quietly.
            pass
210