Преглед изворни кода

improve validation and request multipart form data

Daniel Sheffield пре 1 година
родитељ
комит
0c0def4389
3 измењених фајлова са 49 додато и 25 уклоњено
  1. 1 1
      app/rest/templates/clip-form.tpl
  2. 39 19
      app/rest/validate.py
  3. 9 5
      test/rest/test_url.py

+ 1 - 1
app/rest/templates/clip-form.tpl

@@ -1,7 +1,7 @@
 % from app.data.filter import get_query_param
 % content = setdefault("content", "") or ""
 % disabled = (setdefault("disabled", False) and 'readonly="true"') or ""
-<form id="paste" method="{{ method }}" action="{{ action }}">
+<form id="paste" method="{{ method }}" action="{{ action }}" enctype="multipart/form-data">
   <style>
 textarea::-webkit-scrollbar {
   width: 11px;

+ 39 - 19
app/rest/validate.py

@@ -23,16 +23,22 @@
 from io import BufferedReader
 from itertools import chain, zip_longest
 from bottle import static_file, HTTPError, abort, LocalRequest
-from urllib.parse import urlparse
+from urllib.parse import urlparse, quote, quote_plus, quote_from_bytes, urlencode
 from .hash_util import bytes_to_base32, blake
 
-MUST_ESCAPE = bytes([
+# bytes that must be percent-escaped in a URL, according to RFC 3696
+URL_MUST_ESCAPE = bytes([
     x for x in chain(
+        # control characters
         range(int('0x1F', 0)+1),
+        # DEL (0x7F) and every non-7-bit-ASCII byte (0x80-0xFF)
         range(int('0x7F', 0,), int('0xFF', 0)+1),
+        # specifically excluded
         b'@\\",[]'
     )
 ])
+# urllib.parse.quote follows RFC 3986, so pass it the complement (every byte
+# not in URL_MUST_ESCAPE) as its `safe` argument
+URL_SAFE = bytes(( i for i in range(int('0xff',0)+1) if i not in map(int, URL_MUST_ESCAPE) ))
 
 CLIP_SIZE_LIMIT = 65535
 def validate(filename: str) -> bytes:
@@ -63,12 +69,18 @@ def validate_parameter(request: LocalRequest, name: str) -> bytes:
     if content_length == -1:
         return abort(418, f"Content-Length must be specified")
     if content_length > CLIP_SIZE_LIMIT + OVERHEAD:
-        return abort(418, f"Content-Length can not exceed {CLIP_SIZE_LIMIT+OVERHEAD}")
+        return abort(418, f"Content-Length can not exceed {CLIP_SIZE_LIMIT*3} bytes")
 
+    print(request.content_type)
     # TODO: add test for both query/form param
-    content: bytes = (content or request.params[name]).encode('latin-1')
+    if 'multipart/form-data' in request.content_type:
+        # TODO: what about binary data ?
+        content: bytes = (content or request.params[name].encode('utf-8'))
+    else:
+        content: bytes = (content or request.params[name].encode('latin-1'))
+    
     if len(content) > CLIP_SIZE_LIMIT:
-        return abort(418, f"Paste can not exceed {CLIP_SIZE_LIMIT}")
+        return abort(418, f"Paste can not exceed {CLIP_SIZE_LIMIT} bytes")
     return content
 
 def validate_url(url: str) -> str:
@@ -80,19 +92,27 @@ def validate_url(url: str) -> str:
 
     if scheme in ('http', 'https') and not netloc: return abort(400, "HTTP(S) URL has no netloc")
 
-    encoded = url.encode('utf-8')
-    ret = []
-    for x in encoded:
-        if x in map(int, MUST_ESCAPE):
-            ret.append(f'%{hex(x)[2:].zfill(2)}'.upper())
-            continue
+    if netloc:
+        try:
+            user_info, loc = netloc.rsplit('@', 1)
+        except ValueError:
+            user_info = ''
+            loc = ''
+        if user_info:
+            user_info = quote(user_info, safe=URL_SAFE)
+            netloc = f"{user_info}@{''.join(loc)}"
         else:
-            ret.append(bytes([x]).decode('ascii'))
+            # TODO: do this properly, i.e. validate the DNS name / IP address / port, etc.
+            netloc = quote(netloc, safe=URL_SAFE)
     
-    for idx, (c, *n) in enumerate(zip_longest(ret, ret[1:], ret[2:])):
-        if c == '%':
-            if None not in n and all([i.lower() in '0123456789abcdef' for i in n]):
-                continue
-            ret[idx] = '%25'
-
-    return ''.join(ret)
+    path = quote(path, safe=URL_SAFE)
+    params = quote_plus(params, safe=URL_SAFE)
+    query = quote(query, safe=URL_SAFE)
+    fragment = quote(fragment, safe=URL_SAFE)
+    
+    url = f'{scheme}://{netloc}{path}{params}'
+    if query:
+        url = f'{url}?{query}'
+    if fragment:
+        url = f'{url}#{fragment}'
+    return url

+ 9 - 5
test/rest/test_url.py

@@ -14,9 +14,13 @@ from app.rest.validate import validate_url
     ['file:///a/b/c',]*2,
     ['https://shandan.one',]*2,
     ['https://www.shandan.one',]*2,
-    ['https://www.shandan.one?',]*2,
     ['https://www.shandan.one/clip?id=123',]*2,
-    ['https://www.shandan.one/clip?id=123#',]*2,
+    
+    # empty query
+    ['https://www.shandan.one?', 'https://www.shandan.one',],
+
+    # empty fragment
+    ['https://www.shandan.one/clip?id=123#', 'https://www.shandan.one/clip?id=123',],
 
     # no double slash
     #['file:/a/b/c', (HTTPError, ""),],
@@ -39,15 +43,15 @@ from app.rest.validate import validate_url
     ['https://🌚.shandan.one', 'https://%F0%9F%8C%9A.shandan.one',],
 
     # @ in user_info not allowed
-    # TODO: check this
-    ['https://user@mail@www.shandan.one','https://user%40mail%40www.shandan.one'],
+    # TODO: check this - should the final @ (the user_info/host separator) be left unencoded?
+    ['https://user@mail@www.shandan.one','https://user%40mail@www.shandan.one'],
 
     # delimiters
     # TODO: should < be translated to %3C ?
     ['https://www.shandan.one?a<b', 'https://www.shandan.one?a<b'],
 
     # more delimiters
-    ['https://www.shandan.one/clip?proportion=69%', 'https://www.shandan.one/clip?proportion=69%25'],
+    ['https://www.shandan.one/clip?proportion=69%', 'https://www.shandan.one/clip?proportion=69%'],
 
     # fragment before end of reference URI
     ['https://www.shandan.one/tiny#url?id=123', 'https://www.shandan.one/tiny#url?id=123'],