322 lines
7.4 KiB
C++
322 lines
7.4 KiB
C++
//#define DEBUGGING
|
|
|
|
#include "global.h"
|
|
#include "WebSocketClient.h"
|
|
|
|
#include "sha1.h"
|
|
#include "base64.h"
|
|
|
|
|
|
// Bind this WebSocket object to an already-connected transport and perform
// the client side of the opening handshake.
//
// Returns true when analyzeRequest() accepts the server's reply. Returns false
// if the transport is not connected, or if the handshake is rejected; in the
// latter case the connection is torn down via disconnectStream().
bool WebSocketClient::handshake(Client &client) {
    // Every other method talks to the peer through socket_client.
    socket_client = &client;

    // Nothing to negotiate over a dead transport.
    if (!socket_client->connected()) {
        return false;
    }

#ifdef DEBUGGING
    Serial.println(F("Client connected"));
#endif

    // analyzeRequest() sends the upgrade request and validates the response.
    if (!analyzeRequest()) {
#ifdef DEBUGGING
        Serial.println(F("Invalid handshake"));
#endif
        disconnectStream();
        return false;
    }

#ifdef DEBUGGING
    Serial.println(F("Websocket established"));
#endif
    return true;
}
|
|
|
|
bool WebSocketClient::analyzeRequest() {
|
|
String temp;
|
|
|
|
int bite;
|
|
bool foundupgrade = false;
|
|
unsigned long intkey[2];
|
|
String serverKey;
|
|
char keyStart[17];
|
|
char b64Key[25];
|
|
String key = "------------------------";
|
|
|
|
randomSeed(analogRead(0));
|
|
|
|
for (int i=0; i<16; ++i) {
|
|
keyStart[i] = (char)random(1, 256);
|
|
}
|
|
|
|
base64_encode(b64Key, keyStart, 16);
|
|
|
|
for (int i=0; i<24; ++i) {
|
|
key[i] = b64Key[i];
|
|
}
|
|
|
|
#ifdef DEBUGGING
|
|
Serial.println(F("Sending websocket upgrade headers"));
|
|
#endif
|
|
|
|
socket_client->print(F("GET "));
|
|
socket_client->print(path);
|
|
socket_client->print(F(" HTTP/1.1\r\n"));
|
|
socket_client->print(F("Upgrade: websocket\r\n"));
|
|
socket_client->print(F("Connection: Upgrade\r\n"));
|
|
socket_client->print(F("Host: "));
|
|
socket_client->print(host);
|
|
socket_client->print(CRLF);
|
|
socket_client->print(F("Sec-WebSocket-Key: "));
|
|
socket_client->print(key);
|
|
socket_client->print(CRLF);
|
|
socket_client->print(F("Sec-WebSocket-Protocol: "));
|
|
socket_client->print(protocol);
|
|
socket_client->print(CRLF);
|
|
socket_client->print(F("Sec-WebSocket-Version: 13\r\n"));
|
|
socket_client->print(CRLF);
|
|
|
|
#ifdef DEBUGGING
|
|
Serial.println(F("Analyzing response headers"));
|
|
#endif
|
|
|
|
while (socket_client->connected() && !socket_client->available()) {
|
|
delay(100);
|
|
Serial.println("Waiting...");
|
|
}
|
|
|
|
// TODO: More robust string extraction
|
|
while ((bite = socket_client->read()) != -1) {
|
|
|
|
temp += (char)bite;
|
|
|
|
if ((char)bite == '\n') {
|
|
#ifdef DEBUGGING
|
|
Serial.print("Got Header: " + temp);
|
|
#endif
|
|
if (!foundupgrade && temp.startsWith("Upgrade: websocket")) {
|
|
foundupgrade = true;
|
|
} else if (temp.startsWith("Sec-WebSocket-Accept: ")) {
|
|
serverKey = temp.substring(22,temp.length() - 2); // Don't save last CR+LF
|
|
}
|
|
temp = "";
|
|
}
|
|
|
|
if (!socket_client->available()) {
|
|
delay(20);
|
|
}
|
|
}
|
|
|
|
key += "258EAFA5-E914-47DA-95CA-C5AB0DC85B11";
|
|
uint8_t *hash;
|
|
char result[21];
|
|
char b64Result[30];
|
|
|
|
Sha1.init();
|
|
Sha1.print(key);
|
|
hash = Sha1.result();
|
|
|
|
for (int i=0; i<20; ++i) {
|
|
result[i] = (char)hash[i];
|
|
}
|
|
result[20] = '\0';
|
|
|
|
base64_encode(b64Result, result, 20);
|
|
|
|
// if the keys match, good to go
|
|
return serverKey.equals(String(b64Result));
|
|
}
|
|
|
|
|
|
// Read one complete WebSocket frame from socket_client into `data`.
//
// If `opcode` is non-NULL it receives the first header byte with the FIN bit
// cleared. Masked frames are unmasked. Payload lengths of 7 and 16 bits are
// supported; a 64-bit length makes the function return false.
//
// Returns false without touching `data` when no frame has started to arrive.
// Also returns false if the connection drops or a read fails mid-frame; in
// that case `data` may hold a partial payload.
bool WebSocketClient::handleStream(String& data, uint8_t *opcode) {
    uint8_t msgtype;
    unsigned int length;
    uint8_t mask[4] = {0, 0, 0, 0};  // stays all-zero for unmasked frames
    bool hasMask = false;
    int c;  // raw result of timedRead(); negative means the read failed

    // Only start parsing once a frame has actually begun to arrive.
    if (!socket_client->connected() || !socket_client->available()) {
        return false;
    }

    // Byte 0: FIN/RSV flags and opcode.
    c = timedRead();
    if (c < 0 || !socket_client->connected()) {
        return false;
    }
    msgtype = (uint8_t) c;

    // Byte 1: MASK flag and the 7-bit payload length (or an extended-length marker).
    c = timedRead();
    if (c < 0 || !socket_client->connected()) {
        return false;
    }
    length = (unsigned int) c;

    if (length & WS_MASK) {
        hasMask = true;
        length = length & ~WS_MASK;
    }

    if (length == WS_SIZE16) {
        // Extended 16-bit length, big-endian. Validate each byte before
        // shifting: shifting a negative int is undefined behaviour.
        c = timedRead();
        if (c < 0 || !socket_client->connected()) {
            return false;
        }
        length = (unsigned int) c << 8;

        c = timedRead();
        if (c < 0 || !socket_client->connected()) {
            return false;
        }
        length |= (unsigned int) c;
    } else if (length == WS_SIZE64) {
#ifdef DEBUGGING
        Serial.println(F("No support for over 16 bit sized messages"));
#endif
        return false;
    }

    if (hasMask) {
        // Four-byte masking key.
        for (uint8_t k = 0; k < 4; ++k) {
            c = timedRead();
            if (c < 0 || !socket_client->connected()) {
                return false;
            }
            mask[k] = (uint8_t) c;
        }
    }

    data = "";
    // Pre-size the buffer so appending byte by byte does not reallocate repeatedly.
    data.reserve(length);

    if (opcode != NULL) {
        *opcode = msgtype & ~WS_FIN;
    }

    // XOR with an all-zero mask is a no-op, so one loop serves masked and
    // unmasked frames.
    for (unsigned int i = 0; i < length; ++i) {
        c = timedRead();
        if (c < 0) {
            return false;
        }
        data += (char) ((uint8_t) c ^ mask[i % 4]);
        if (!socket_client->connected()) {
            return false;
        }
    }

    return true;
}
|
|
|
|
void WebSocketClient::disconnectStream() {
|
|
#ifdef DEBUGGING
|
|
Serial.println(F("Terminating socket"));
|
|
#endif
|
|
// Should send 0x8700 to server to tell it I'm quitting here.
|
|
socket_client->write((uint8_t) 0x87);
|
|
socket_client->write((uint8_t) 0x00);
|
|
|
|
socket_client->flush();
|
|
delay(10);
|
|
socket_client->stop();
|
|
}
|
|
|
|
// Public entry point for receiving: fetch the next frame's payload into
// `data` (and optionally its opcode). Returns whatever handleStream() reports.
bool WebSocketClient::getData(String& data, uint8_t *opcode) {
    const bool received = handleStream(data, opcode);
    return received;
}
|
|
|
|
void WebSocketClient::sendData(const char *str, uint8_t opcode) {
|
|
#ifdef DEBUGGING
|
|
Serial.print(F("Sending data: "));
|
|
Serial.println(str);
|
|
#endif
|
|
if (socket_client->connected()) {
|
|
sendEncodedData(str, opcode);
|
|
}
|
|
}
|
|
|
|
// Send an Arduino String as a single final WebSocket frame with the given
// opcode. The message is silently dropped when the client is not connected.
void WebSocketClient::sendData(String str, uint8_t opcode) {
#ifdef DEBUGGING
    Serial.print(F("Sending data: "));
    Serial.println(str);
#endif

    if (!socket_client->connected()) {
        return;
    }

    sendEncodedData(str, opcode);
}
|
|
|
|
int WebSocketClient::timedRead() {
|
|
while (!socket_client->available()) {
|
|
delay(20);
|
|
}
|
|
|
|
return socket_client->read();
|
|
}
|
|
|
|
void WebSocketClient::sendEncodedData(char *str, uint8_t opcode) {
|
|
uint8_t mask[4];
|
|
int size = strlen(str);
|
|
|
|
// Opcode; final fragment
|
|
socket_client->write(opcode | WS_FIN);
|
|
|
|
// NOTE: no support for > 16-bit sized messages
|
|
if (size > 125) {
|
|
socket_client->write(WS_SIZE16 | WS_MASK);
|
|
socket_client->write((uint8_t) (size >> 8));
|
|
socket_client->write((uint8_t) (size & 0xFF));
|
|
} else {
|
|
socket_client->write((uint8_t) size | WS_MASK);
|
|
}
|
|
|
|
mask[0] = random(0, 256);
|
|
mask[1] = random(0, 256);
|
|
mask[2] = random(0, 256);
|
|
mask[3] = random(0, 256);
|
|
|
|
socket_client->write(mask[0]);
|
|
socket_client->write(mask[1]);
|
|
socket_client->write(mask[2]);
|
|
socket_client->write(mask[3]);
|
|
|
|
for (int i=0; i<size; ++i) {
|
|
socket_client->write(str[i] ^ mask[i % 4]);
|
|
}
|
|
}
|
|
|
|
// Frame an Arduino String as one final, masked WebSocket frame by delegating
// to the char* overload.
//
// The String's own NUL-terminated buffer is passed directly. The previous
// version copied it into a variable-length stack array, which is non-standard
// C++ and can overflow the small stacks of microcontrollers on long messages.
void WebSocketClient::sendEncodedData(String str, uint8_t opcode) {
    const char *payload = str.c_str();

    // c_str() may yield NULL for a String whose allocation failed; send an
    // empty payload in that case instead of passing NULL on.
    if (payload == NULL) {
        payload = "";
    }

    // The char* overload only reads the buffer, so casting away const is safe.
    sendEncodedData(const_cast<char *>(payload), opcode);
}
|