-- 11.05.2020, 11:24
-- Try this:
-- Code:
local bit = require('bit')
local ssl = require('ssl')
local socket = require('socket')
local encdec = require('encdec')
local parse_url = require('socket.url').parse
local bxor = bit.bxor
local bor = bit.bor
local band = bit.band
local lshift = bit.lshift
local rshift = bit.rshift
local ssub = string.sub
local sbyte = string.byte
local schar = string.char
local tinsert = table.insert
local tconcat = table.concat
local mmin = math.min
local mfloor = math.floor
local mrandom = math.random
local base64enc = encdec.base64enc
local sha1 = encdec.sha1
local unpack = unpack
-- Frame opcodes, as placed in the low nibble of a frame's first byte.
local CONTINUATION, TEXT, BINARY = 0, 1, 2
local CLOSE, PING, PONG = 8, 9, 10
-- Fixed GUID appended to the Sec-WebSocket-Key before hashing to obtain the
-- expected Sec-WebSocket-Accept value.
local guid = '258EAFA5-E914-47DA-95CA-C5AB0DC85B11'
--- Read `n` consecutive bytes from `str`.
-- @param str source string
-- @param pos 1-based start position (default 1)
-- @param n number of bytes to read
-- @return the position just after the bytes, followed by each byte value
local read_n_bytes = function(str, pos, n)
  local start = pos or 1
  local stop = start + n - 1
  return start + n, string.byte(str, start, stop)
end
--- Read one unsigned byte.
-- @param str source string
-- @param pos 1-based position (default 1)
-- @return the position after the byte, the byte value
local read_int8 = function(str, pos)
  return read_n_bytes(str, pos or 1, 1)
end
--- Read a big-endian unsigned 16-bit integer.
-- @param str source string
-- @param pos 1-based start position (default 1)
-- @return the position after the two bytes, the value
local read_int16 = function(str, pos)
  local after, hi, lo = read_n_bytes(str, pos, 2)
  local value = lshift(hi, 8) + lo
  return after, value
end
--- Read a big-endian unsigned 32-bit integer.
-- @param str source string
-- @param pos 1-based start position (default 1)
-- @return the position after the four bytes, the value in [0, 2^32)
local read_int32 = function(str, pos)
  pos = pos or 1
  local a, b, c, d = string.byte(str, pos, pos + 3)
  -- Plain arithmetic instead of bit.lshift: the bit library produces signed
  -- 32-bit results, so lshift(a, 24) is negative whenever a >= 0x80, which
  -- would corrupt the 64-bit frame length decode() assembles from two reads.
  return pos + 4, ((a * 256 + b) * 256 + c) * 256 + d
end
-- Big-endian integer writers (inverse of the read_int* helpers).
local write_int8 = schar
--- Encode `v` as two bytes, high byte first.
local write_int16 = function(v)
  local hi = rshift(v, 8)
  local lo = band(v, 0xFF)
  return schar(hi, lo)
end
--- Encode the low 32 bits of `v` as four bytes, most significant first.
local write_int32 = function(v)
  local out = {}
  for shift = 24, 0, -8 do
    out[#out + 1] = band(rshift(v, shift), 0xFF)
  end
  return schar(unpack(out))
end
--- Produce a fresh Sec-WebSocket-Key: 16 random bytes, base64-encoded.
-- NOTE(review): math.random is not a cryptographic RNG; consider a stronger
-- source if one is available in the target environment.
local generate_key
do
  local seeded = false
  generate_key = function()
    -- Seed only once: reseeding from os.time() on every call would make all
    -- keys generated within the same second identical, and would also reset
    -- the sequence encode() draws its masking keys from.
    if not seeded then
      math.randomseed(os.time())
      seeded = true
    end
    -- Draw each of the 16 nonce bytes over the full 0..255 range so every bit
    -- of the nonce is random.
    local nonce = {}
    for i = 1, 16 do
      nonce[i] = mrandom(0, 0xff)
    end
    return base64enc(schar(unpack(nonce, 1, 16)))
  end
end
--- Sum 2^n over every given bit position, e.g. bits(0, 1) == 3.
local function bits(...)
  local total = 0
  for _, position in pairs({...}) do
    total = total + 2 ^ position
  end
  return total
end
-- Masks: bit_7 = 0x80 (FIN / MASK flag), bit_0_3 = 0x0F (opcode nibble),
-- bit_0_6 = 0x7F (7-bit payload length).
local bit_7, bit_0_3, bit_0_6 = bits(7), bits(0, 1, 2, 3), bits(0, 1, 2, 3, 4, 5, 6)
-- TODO: improve performance
--- XOR the first `payload` bytes of `encoded` with the repeating 4-byte `mask`.
-- Work proceeds in 2000-byte slices because sbyte/schar move every byte as a
-- separate argument/return value, which needs stack space. 2000 is a multiple
-- of 4, so the mask position stays aligned from one slice to the next.
-- @param encoded source bytes (at least `payload` long)
-- @param mask array of four byte values
-- @param payload number of bytes to transform
-- @return the transformed string
local xor_mask = function(encoded,mask,payload)
  local SLICE = 2000
  local scratch, pieces = {}, {}
  local first = 1
  while first <= payload do
    local last = mmin(first + SLICE - 1, payload)
    local src = {sbyte(encoded, first, last)}
    local count = #src
    for k = 1, count do
      scratch[k] = bxor(src[k], mask[(k - 1) % 4 + 1])
    end
    pieces[#pieces + 1] = schar(unpack(scratch, 1, count))
    first = first + SLICE
  end
  return tconcat(pieces)
end
--- Two-byte header: first byte followed by the MASK bit / 7-bit length byte.
local encode_header_small = function(header, payload)
  return schar(header, payload)
end
--- Header carrying a 16-bit extended length, written high byte first.
local encode_header_medium = function(header, payload, len)
  local len_hi = band(rshift(len, 8), 0xFF)
  local len_lo = band(len, 0xFF)
  return schar(header, payload, len_hi, len_lo)
end
--- Header carrying a 64-bit extended length given as two 32-bit halves.
local encode_header_big = function(header, payload, high, low)
  local prefix = schar(header, payload)
  return prefix .. write_int32(high) .. write_int32(low)
end
--- Build one WebSocket frame around `data`.
-- @param data payload string
-- @param opcode frame opcode; 1 (TEXT) when nil
-- @param masked when truthy, set the MASK bit and XOR the payload with four
--   bytes drawn from math.random (written right after the header)
-- @param fin nil or true sets the FIN bit; any other value leaves it clear
-- @return the frame as a string
local encode = function(data,opcode,masked,fin)
  local first = opcode or 1 -- TEXT is default opcode
  local set_fin = (fin == nil) or (fin == true)
  if set_fin then
    first = bor(first, bit_7)
  end
  local second = 0
  if masked then
    second = bor(second, bit_7)
  end
  local size = #data
  local parts = {}
  -- Pick the shortest length encoding that fits.
  if size < 126 then
    parts[1] = encode_header_small(first, bor(second, size))
  elseif size <= 0xffff then
    parts[1] = encode_header_medium(first, bor(second, 126), size)
  elseif size < 2^53 then
    local high = mfloor(size / 2^32)
    parts[1] = encode_header_big(first, bor(second, 127), high, size - high * 2^32)
  end
  if masked then
    local m1 = mrandom(0,0xff)
    local m2 = mrandom(0,0xff)
    local m3 = mrandom(0,0xff)
    local m4 = mrandom(0,0xff)
    parts[#parts + 1] = write_int8(m1, m2, m3, m4)
    parts[#parts + 1] = xor_mask(data, {m1, m2, m3, m4}, size)
  else
    parts[#parts + 1] = data
  end
  return tconcat(parts)
end
--- Decode the frame at the start of `encoded`.
-- @param encoded bytes received so far (may hold less than a full frame)
-- @return on success: payload, fin (boolean), opcode, the bytes that follow
--   this frame, whether the frame was masked (boolean)
-- @return when the frame is incomplete: nil, number of bytes still missing
-- Raises an assertion error on an invalid 64-bit length field.
local decode = function(encoded)
  local encoded_bak = encoded
  if #encoded < 2 then
    return nil,2-#encoded
  end
  local pos,header,payload
  pos,header = read_int8(encoded,1)
  pos,payload = read_int8(encoded,pos)
  encoded = ssub(encoded,pos)
  -- Total bytes this frame occupies (header, length, mask key, payload).
  local bytes = 2
  local fin = band(header,bit_7) > 0
  local opcode = band(header,bit_0_3)
  local mask = band(payload,bit_7) > 0
  payload = band(payload,bit_0_6)
  -- payload is now 0..127; 126 and 127 announce an extended length field.
  if payload == 126 then
    if #encoded < 2 then
      return nil,2-#encoded
    end
    pos,payload = read_int16(encoded,1)
    encoded = ssub(encoded,pos)
    bytes = bytes + pos - 1
  elseif payload == 127 then
    if #encoded < 8 then
      return nil,8-#encoded
    end
    local high,low
    pos,high = read_int32(encoded,1)
    pos,low = read_int32(encoded,pos)
    payload = high*2^32 + low
    -- The 64-bit form is only valid for lengths that do not fit in 16 bits
    -- (RFC 6455 5.2 requires the minimal encoding), so 0xffff is rejected too.
    if payload <= 0xffff or payload > 2^53 then
      assert(false,'INVALID PAYLOAD '..payload)
    end
    encoded = ssub(encoded,pos)
    bytes = bytes + pos - 1
  end
  local decoded
  if mask then
    -- Need the 4-byte masking key plus the whole payload.
    local bytes_short = payload + 4 - #encoded
    if bytes_short > 0 then
      return nil,bytes_short
    end
    local m1,m2,m3,m4
    pos,m1 = read_int8(encoded,1)
    pos,m2 = read_int8(encoded,pos)
    pos,m3 = read_int8(encoded,pos)
    pos,m4 = read_int8(encoded,pos)
    encoded = ssub(encoded,pos)
    decoded = xor_mask(encoded,{m1,m2,m3,m4},payload)
    bytes = bytes + 4 + payload
  else
    local bytes_short = payload - #encoded
    if bytes_short > 0 then
      return nil,bytes_short
    end
    if #encoded > payload then
      decoded = ssub(encoded,1,payload)
    else
      decoded = encoded
    end
    bytes = bytes + payload
  end
  return decoded,fin,opcode,encoded_bak:sub(bytes+1),mask
end
--- Build the payload of a CLOSE frame.
-- @param code optional status code, written as two big-endian bytes
-- @param reason optional reason (passed through tostring); ignored without a code
-- @return the payload string ('' when no code is given)
local encode_close = function(code,reason)
  if not code then
    return ''
  end
  local body = write_int16(code)
  if reason then
    body = body .. tostring(reason)
  end
  return body
end
--- Split a CLOSE frame payload into status code and reason.
-- @param data payload string or nil
-- @return code (nil when fewer than 2 bytes), reason (nil when no bytes follow the code)
local decode_close = function(data)
  if not data then
    return nil, nil
  end
  local code, reason
  local size = #data
  if size >= 2 then
    code = select(2, read_int16(data, 1))
  end
  if size >= 3 then
    reason = data:sub(3)
  end
  return code, reason
end
--- Compute the Sec-WebSocket-Accept value the server must answer with:
-- base64 of sha1(key .. guid).
-- NOTE(review): assumes the `true` flag makes encdec.sha1 return the raw
-- binary digest rather than hex — confirm against encdec.
local sec_websocket_accept = function(sec_websocket_key)
  return base64enc((sha1(sec_websocket_key .. guid, true)))
end
--- Parse the header lines of an HTTP/1.1 message.
-- Header names are lower-cased; values are lower-cased too except for
-- Sec-WebSocket-* headers. Repeated headers are joined with ','.
-- @param request raw message text, status/request line first, CRLF line endings
-- @return table of headers (empty when the text is not HTTP/1.1), and the
--   text after the first blank line, if any
local http_headers = function(request)
  local headers = {}
  if not request:match('.*HTTP/1%.1') then
    return headers
  end
  request = request:match('[^\r\n]+\r\n(.*)')
  if not request then
    -- Only a first line, nothing after it.
    return headers
  end
  for line in request:gmatch('[^\r\n]*\r\n') do
    -- Stop the name at the first ':' (field names cannot contain one) so values
    -- such as "host:port" are not split at the wrong colon; allow empty values.
    local name,val = line:match('([^%s:]+)%s*:%s*([^\r\n]*)')
    if name and val then
      name = name:lower()
      if not name:match('sec%-websocket') then
        val = val:lower()
      end
      if not headers[name] then
        headers[name] = val
      else
        headers[name] = headers[name]..','..val
      end
    elseif line ~= '\r\n' then
      assert(false,line..'('..#line..')')
    end
  end
  return headers,request:match('\r\n\r\n(.*)')
end
--- Build the HTTP upgrade (opening handshake) request.
-- @param req table from socket.url.parse (fields used: path, query, host,
--   port, scheme, userinfo)
-- @param key Sec-WebSocket-Key value
-- @param protocol optional Sec-WebSocket-Protocol value
-- @return the complete request text, terminated by an empty line
local upgrade_request = function(req, key, protocol)
  local format = string.format
  -- The request target must not be empty: default to '/' and keep any query
  -- string that the URL parser split off.
  local target = req.path
  if target == nil or target == '' then
    target = '/'
  end
  if req.query then
    target = target..'?'..req.query
  end
  -- Include the port in Host only when it differs from the scheme's default.
  -- The parser yields ports as strings, so normalize before comparing.
  local port = tonumber(req.port)
  local default_port = (req.scheme == 'wss') and 443 or 80
  local host = req.host
  if port and port ~= default_port then
    host = format('%s:%d', req.host, port)
  end
  local lines = {
    format('GET %s HTTP/1.1', target),
    format('Host: %s', host),
    'Upgrade: websocket',
    'Connection: Upgrade',
    format('Sec-WebSocket-Key: %s', key),
    'Sec-WebSocket-Version: 13',
  }
  if protocol then
    lines[#lines + 1] = format('Sec-WebSocket-Protocol: %s', protocol)
  end
  if req.userinfo then
    lines[#lines + 1] = format('Authorization: Basic %s', base64enc(req.userinfo))
  end
  -- Joined with CRLF, this trailing element yields the terminating blank line.
  lines[#lines + 1] = '\r\n'
  return table.concat(lines, '\r\n')
end
--- Read one complete message, reassembling fragmented frames.
-- A CLOSE from the peer is echoed and the connection is torn down, unless
-- close() is already waiting for it (self.is_closing), in which case the
-- CLOSE payload is returned to the caller.
-- @return on success: payload string, opcode of the first frame
-- @return on failure: nil, nil, was_clean, close code, reason
local receive = function(self)
  if self.state ~= 'OPEN' and not self.is_closing then
    return nil,nil,false,1006,'wrong state'
  end
  local first_opcode
  local frames
  -- Start with the 2-byte minimum frame header and let decode() say how many
  -- more bytes it needs. Asking for more up front would block on, or eat the
  -- first byte of the next frame after, a frame with an empty payload (e.g.
  -- an empty PING/PONG or a CLOSE without a status code), because the unused
  -- tail returned by decode() is discarded below.
  local bytes = 2
  local encoded = ''
  local clean = function(was_clean,code,reason)
    self.state = 'CLOSED'
    self:sock_close()
    if self.on_close then
      self:on_close()
    end
    return nil,nil,was_clean,code,reason or 'closed'
  end
  while true do
    local chunk,err = self:sock_receive(bytes)
    if err then
      if err == 'timeout' then
        -- The connection stays open. NOTE(review): bytes of a partially read
        -- frame are dropped here, so the next receive() may start mid-frame.
        return nil,nil,false,1006,err
      end
      return clean(false,1006,err)
    end
    encoded = encoded..chunk
    local decoded,fin,opcode,_,masked = decode(encoded)
    if masked then
      -- Frames sent by a server must not be masked (RFC 6455, section 5.1).
      return clean(false,1006,'Websocket receive failed: frame from server was masked')
    end
    if not decoded then
      -- Incomplete frame: decode() put the number of missing bytes in `fin`.
      assert(type(fin) == 'number' and fin > 0)
      bytes = fin
    elseif opcode == CLOSE then
      if self.is_closing then
        -- close() is waiting for the peer's reply; hand the frame back.
        return decoded,opcode
      end
      -- Peer-initiated close: echo its status code, then tear down.
      local code,reason = decode_close(decoded)
      local reply = encode(encode_close(code),CLOSE,true)
      local n,send_err = self:sock_send(reply)
      if n == #reply then
        return clean(true,code,reason)
      end
      return clean(false,code,send_err)
    else
      if not first_opcode then
        first_opcode = opcode
      end
      if not fin then
        -- Non-final fragment: every fragment after the first must be CONTINUATION.
        if not frames then
          frames = {}
        elseif opcode ~= CONTINUATION then
          return clean(false,1002,'protocol error')
        end
        bytes = 2
        encoded = ''
        frames[#frames + 1] = decoded
      elseif not frames then
        return decoded,first_opcode
      else
        frames[#frames + 1] = decoded
        return tconcat(frames),first_opcode
      end
    end
  end
end
--- Send `data` as one masked frame.
-- @param data payload string
-- @param opcode frame opcode (default TEXT)
-- @return true on success; otherwise nil, was_clean (false), close code, reason
local send = function(self,data,opcode)
  if self.state ~= 'OPEN' then
    return nil,false,1006,'wrong state'
  end
  local encoded = encode(data,opcode or TEXT,true)
  local n,err = self:sock_send(encoded)
  if n ~= #encoded then
    -- The stream may now end in a partial frame, so no further frame (a CLOSE
    -- included) can be sent reliably; 1006 is also reserved and must never be
    -- put on the wire. Drop the connection and report the abnormal closure.
    self.state = 'CLOSED'
    self:sock_close()
    if self.on_close then
      self:on_close()
    end
    return nil,false,1006,err
  end
  return true
end
--- Run the closing handshake: send CLOSE, wait for the peer's CLOSE, then
-- close the socket.
-- @param code status code to send (default 1000)
-- @param reason optional reason text sent after the code
-- @return was_clean, close code (the peer's, or 1005 when none was received
--   or it carried none), reason ('' when absent)
local close = function(self,code,reason)
  if self.state ~= 'OPEN' then
    return false,1006,'wrong state'
  end
  local msg = encode_close(code or 1000,reason)
  local encoded = encode(msg,CLOSE,true)
  local n,err = self:sock_send(encoded)
  local was_clean = false
  code = 1005
  reason = ''
  if n == #encoded then
    self.is_closing = true
    local rmsg,opcode = self:receive()
    if rmsg and opcode == CLOSE then
      local peer_code,peer_reason = decode_close(rmsg)
      -- A CLOSE without a status code is reported as 1005.
      code = peer_code or 1005
      reason = peer_reason
      was_clean = true
    end
  else
    reason = err
  end
  -- Reset so a later connection on this object handles peer CLOSEs normally.
  self.is_closing = false
  -- receive() may already have torn the connection down (and fired on_close)
  -- on a read error; do not close or notify twice.
  if self.state ~= 'CLOSED' then
    self.state = 'CLOSED'
    self:sock_close()
    if self.on_close then
      self:on_close()
    end
  end
  return was_clean,code,reason or ''
end
local DEFAULT_PORTS = {ws = 80, wss = 443}
--- Open the connection (TLS for wss://) and perform the opening handshake.
-- @param ws_url ws:// or wss:// URL
-- @param ssl_params optional LuaSec parameter table (mode is forced to 'client')
-- @return true, negotiated Sec-WebSocket-Protocol (or nil), response headers
-- @return on failure: nil, error message, response headers when available
local connect = function(self,ws_url,ssl_params)
  if self.state ~= 'CLOSED' then
    return nil,'wrong state',nil
  end
  local parsed = parse_url(ws_url)
  if not parsed or (parsed.scheme ~= 'wss' and parsed.scheme ~= 'ws') then
    return nil, 'bad protocol'
  end
  if not parsed.port then
    parsed.port = DEFAULT_PORTS[ parsed.scheme ]
  end
  -- Preconnect (for SSL if needed)
  local _,err = self:sock_connect(parsed.host, parsed.port)
  if err then
    return nil,err,nil
  end
  -- From here on the socket is open; every failure path must release it.
  local fail = function(reason, headers)
    self:sock_close()
    return nil, reason, headers
  end
  if parsed.scheme == 'wss' then
    if type(ssl_params) ~= 'table' then
      ssl_params = {
        protocol = 'tlsv1',
        options = {'all', 'no_sslv2', 'no_sslv3'},
        verify = 'none',
      }
    end
    ssl_params.mode = 'client'
    local wrapped, wrap_err = ssl.wrap(self.sock, ssl_params)
    if not wrapped then
      return fail(wrap_err)
    end
    self.sock = wrapped
    local ok, handshake_err = self.sock:dohandshake()
    if not ok then
      return fail(handshake_err)
    end
  end
  local key = generate_key()
  local req = upgrade_request(parsed, key, self.protocol)
  local n,send_err = self:sock_send(req)
  if n ~= #req then
    return fail(send_err)
  end
  -- Read the response head line by line up to the empty line.
  local resp = {}
  repeat
    local line,recv_err = self:sock_receive('*l')
    if recv_err then
      return fail(recv_err)
    end
    resp[#resp+1] = line
  until line == ''
  local response = tconcat(resp,'\r\n')
  local headers = http_headers(response)
  local expected_accept = sec_websocket_accept(key)
  if headers['sec-websocket-accept'] ~= expected_accept then
    local msg = 'Websocket Handshake failed: Invalid Sec-Websocket-Accept (expected %s got %s)'
    return fail(msg:format(expected_accept,headers['sec-websocket-accept'] or 'nil'),headers)
  end
  self.is_closing = false
  self.state = 'OPEN'
  return true,headers['sec-websocket-protocol'],headers
end
--- Attach the WebSocket methods to a transport object and mark it CLOSED.
-- @param obj table providing sock_connect/sock_send/sock_receive/sock_close
-- @return the same table
local extend = function(obj)
  obj.connect = connect
  obj.send = send
  obj.receive = receive
  obj.close = close
  obj.state = 'CLOSED'
  return obj
end
--- Create a client whose socket I/O goes through copas.
-- @param timeout socket timeout in seconds (default 5)
local client_copas = function(timeout)
  local copas = require('copas')
  local client = {}
  function client:sock_connect(host, port)
    local tcp = socket.tcp()
    self.sock = tcp
    tcp:settimeout(timeout or 5)
    local _, err = copas.connect(tcp, host, port)
    -- 'already connected' is treated as success.
    if err and err ~= 'already connected' then
      tcp:close()
      return nil, err
    end
  end
  function client:sock_send(...)
    return copas.send(self.sock, ...)
  end
  function client:sock_receive(...)
    return copas.receive(self.sock, ...)
  end
  function client:sock_close()
    self.sock:close()
  end
  return extend(client)
end
--- Create a client that uses plain blocking LuaSocket calls.
-- @param timeout socket timeout in seconds (default 5)
local client_sync = function(timeout)
  local client = {}
  function client:sock_connect(host, port)
    local tcp = socket.tcp()
    self.sock = tcp
    tcp:settimeout(timeout or 5)
    local _, err = tcp:connect(host, port)
    if err then
      tcp:close()
      return nil, err
    end
  end
  function client:sock_send(...)
    return self.sock:send(...)
  end
  function client:sock_receive(...)
    return self.sock:receive(...)
  end
  function client:sock_close()
    self.sock:close()
  end
  return extend(client)
end
--- Create a client; mode 'copas' selects the copas transport, anything else
-- the blocking LuaSocket one.
local client = function(mode, timeout)
  local factory = (mode == 'copas') and client_copas or client_sync
  return factory(timeout)
end
-- Public API: the opcode constants plus the client factory.
return {
  CONTINUATION = CONTINUATION,
  TEXT = TEXT,
  BINARY = BINARY,
  CLOSE = CLOSE,
  PING = PING,
  PONG = PONG,
  client = client,
}