8 #include <botan/tls_client.h>
9 #include <botan/internal/tls_alerts.h>
10 #include <botan/internal/tls_state.h>
11 #include <botan/loadstor.h>
12 #include <botan/rsa.h>
13 #include <botan/dsa.h>
// Exception raised by client_check_state when a handshake message arrives
// in a state where it is not allowed. Derives from Unexpected_Message; the
// text passed in is prefixed with "State transition error from ".
// NOTE(review): this extract omits the class braces / access specifier;
// only the base clause and the constructor are visible.
26 class State_Transition_Error :
public Unexpected_Message
29 State_Transition_Error(
const std::string& err) :
30 Unexpected_Message(
"State transition error from " + err) {}
// client_check_state (fragment): verifies that a received handshake message
// is legal given which messages `state` has already recorded. Each check
// below appears (judging by its error text) to be the case for the named
// message type; the dispatching if/switch lines were lost in extraction.
// The trailing Unexpected_Message appears to cover any other message type.
//
// HelloRequest: rejected once a ClientHello exists.
35 if(state->client_hello)
36 throw State_Transition_Error(
"HelloRequest");
// ServerHello: needs a prior ClientHello and may not be repeated.
40 if(!state->client_hello || state->server_hello)
41 throw State_Transition_Error(
"ServerHello");
// ServerCertificate: after ServerHello, before ServerKeyExchange,
// CertificateRequest and ServerHelloDone.
// NOTE(review): state->server_certs is not tested here, so a second
// Certificate message seems to pass this check; process_handshake_msg then
// overwrites state->server_certs (possible leak) — confirm.
45 if(!state->server_hello || state->server_kex ||
46 state->cert_req || state->server_hello_done)
47 throw State_Transition_Error(
"ServerCertificate");
// ServerKeyExchange: same ordering constraints as ServerCertificate.
51 if(!state->server_hello || state->server_kex ||
52 state->cert_req || state->server_hello_done)
53 throw State_Transition_Error(
"ServerKeyExchange");
// CertificateRequest: requires a server certificate, may not be repeated,
// and must precede ServerHelloDone.
57 if(!state->server_certs || state->cert_req || state->server_hello_done)
58 throw State_Transition_Error(
"CertificateRequest");
// ServerHelloDone: after ServerHello, not repeated.
62 if(!state->server_hello || state->server_hello_done)
63 throw State_Transition_Error(
"ServerHelloDone");
// Server ChangeCipherSpec: only once state->client_finished is set and
// before the server's Finished.
67 if(!state->client_finished || state->server_finished)
68 throw State_Transition_Error(
"ServerChangeCipherSpec");
// ServerFinished: only after the server's ChangeCipherSpec was seen.
72 if(!state->got_server_ccs)
73 throw State_Transition_Error(
"ServerFinished");
76 throw Unexpected_Message(
"Unexpected message in handshake");
// TLS_Client constructor (parameter fragment): output_fn is a callback
// taking (const byte[], size_t); presumably the sink for bytes destined for
// the peer — confirm against the caller.
85 std::tr1::function<
void (
const byte[],
size_t)> output_fn,
// add_client_cert (fragment): store the certificate together with its
// private key. The cleanup loop that follows deletes every stored key, so
// the client takes ownership of cert_key.
99 certs.push_back(std::make_pair(cert, cert_key));
// Presumably the destructor (fragment): delete each private key stored by
// add_client_cert.
108 for(
size_t i = 0; i != certs.size(); i++)
109 delete certs[i].second;
// initialize (fragment): runs the initial handshake. The handlers capture
// two kinds of failure:
//   - a TLS_Exception: its message and alert type;
//   - any other std::exception: its message.
// A Stream_IO_Error carrying the captured message is thrown on failure.
// NOTE(review): the try body, error_type's declaration and the condition
// guarding the throw are not visible in this extract.
116 void TLS_Client::initialize()
118 std::string error_str;
127 catch(TLS_Exception& e)
129 error_str = e.what();
130 error_type = e.type();
132 catch(std::exception& e)
134 error_str = e.what();
155 throw Stream_IO_Error(
"TLS_Client: Handshake failed: " + error_str);
// write (fragment): error text used when write() is called on a closed
// connection; the statement it belongs to is not visible here.
174 "TLS_Client::write called while closed");
// read (fragment): loops while read_buf is empty (loop body not visible),
// then copies at most `length` buffered bytes into `out`.
189 while(read_buf.
size() == 0)
196 size_t got = std::min<size_t>(read_buf.
size(), length);
197 read_buf.
read(out, got);
// alert (fragment): pass the alert level and code to the record writer.
227 writer.
alert(level, alert_code);
// state_machine (fragment): pulls a record from the transport via input_fn
// and dispatches it by record type.
239 void TLS_Client::state_machine()
242 SecureVector<byte> record(1024);
// get_record reports how many more input bytes it needs (presumably 0 once
// a full record is available — confirm against the record reader).
244 size_t bytes_needed = reader.
get_record(rec_type, record);
// Request no more than the scratch buffer holds or the reader asked for.
248 size_t to_get = std::min<size_t>(record.size(), bytes_needed);
249 size_t got = input_fn(&record[0], to_get);
259 bytes_needed = reader.
get_record(rec_type, record);
// Application data is buffered for read(); per the error text it is
// rejected before the handshake completes (guard not visible).
// NOTE(review): &record[0] is out of range if record can be empty here.
271 read_buf.
write(&record[0], record.size());
273 throw Unexpected_Message(
"Application data before handshake done");
// Handshake records (and, judging by read_handshake, ChangeCipherSpec)
// are handed to read_handshake.
276 read_handshake(rec_type, record);
277 else if(rec_type ==
ALERT)
297 throw Unexpected_Message(
"Unknown message type received");
// read_handshake (fragment): appends handshake bytes to state->queue and
// extracts a complete message. The 4-byte header holds a type byte
// (presumably head[0]) and a 24-bit length in head[1..3]. A ChangeCipherSpec
// record must be exactly the single byte 0x01 with no queued handshake data.
303 void TLS_Client::read_handshake(
byte rec_type,
304 const MemoryRegion<byte>& rec_buf)
// NOTE(review): if rec_buf can be empty, &rec_buf[0] indexes past the end —
// confirm the record layer never delivers empty records.
311 state->queue.write(&rec_buf[0], rec_buf.size());
317 SecureVector<byte> contents;
321 if(state->queue.size() >= 4)
323 byte head[4] = { 0 };
324 state->queue.peek(head, 4);
// Message length from header bytes 1..3 (byte order per make_u32bit).
326 const size_t length =
make_u32bit(0, head[1], head[2], head[3]);
// Consume the message only once the header and the whole body are queued.
328 if(state->queue.size() >= length + 4)
// NOTE(review): for a zero-length message body, &contents[0] indexes an
// empty buffer — confirm MemoryRegion tolerates this.
331 contents.resize(length);
332 state->queue.read(head, 4);
333 state->queue.read(&contents[0], contents.size());
339 if(state->queue.size() == 0 && rec_buf.size() == 1 && rec_buf[0] == 1)
342 throw Decoding_Error(
"Malformed ChangeCipherSpec message");
345 throw Decoding_Error(
"Unknown message type in handshake processing");
350 process_handshake_msg(type, contents);
// process_handshake_msg (fragment): feeds the message into the running
// handshake hash, validates the transition with client_check_state, and
// handles each server message type. After ServerHelloDone it builds the
// client's flight.
361 const MemoryRegion<byte>& contents)
367 throw Unexpected_Message(
"Unexpected handshake message");
// Hash the type byte, the low three bytes of the body length, then the body.
371 state->hash.update(static_cast<byte>(type));
372 const size_t record_length = contents.size();
373 for(
size_t i = 0; i != 3; i++)
374 state->hash.update(get_byte<u32bit>(i+1, record_length));
375 state->hash.update(contents);
// ServerHello: the chosen suite must be one we offered, and the server's
// version may not exceed the ClientHello's version.
380 client_check_state(type, state);
382 state->server_hello =
new Server_Hello(contents);
384 if(!state->client_hello->offered_suite(
385 state->server_hello->ciphersuite()
389 "TLS_Client: Server replied with bad ciphersuite");
391 state->version = state->server_hello->version();
393 if(state->version > state->client_hello->version())
395 "TLS_Client: Server replied with bad version");
399 "TLS_Client: Server is too old for specified policy");
404 state->suite = CipherSuite(state->server_hello->ciphersuite());
// Server Certificate (rejected for anonymous suites, per the error text):
// take the leaf certificate's public key and check that its algorithm
// (DSA or RSA) matches the ciphersuite.
// NOTE(review): "Recived" below is a typo in the runtime message.
408 client_check_state(type, state);
411 throw Unexpected_Message(
"Recived certificate from anonymous server");
413 state->server_certs =
new Certificate(contents);
415 peer_certs = state->server_certs->cert_chain();
416 if(peer_certs.size() == 0)
418 "TLS_Client: No certificates sent by server");
422 "TLS_Client: Server certificate is not valid");
424 state->kex_pub = peer_certs[0].subject_public_key();
426 bool is_dsa =
false, is_rsa =
false;
428 if(dynamic_cast<DSA_PublicKey*>(state->kex_pub))
430 else if(dynamic_cast<RSA_PublicKey*>(state->kex_pub))
434 "Unknown key type received in server kex");
439 "Certificate key type did not match ciphersuite");
// ServerKeyExchange: replace the certificate key with the server-supplied
// key (DH or RSA). Then verify the signature against the server
// certificate, covering both hello randoms.
// NOTE(review): the guard around verify() is not visible; peer_certs[0] is
// only valid if a certificate was received — confirm.
443 client_check_state(type, state);
446 throw Unexpected_Message(
"Unexpected key exchange from server");
448 state->server_kex =
new Server_Key_Exchange(contents);
451 delete state->kex_pub;
453 state->kex_pub = state->server_kex->key();
455 bool is_dh =
false, is_rsa =
false;
457 if(dynamic_cast<DH_PublicKey*>(state->kex_pub))
459 else if(dynamic_cast<RSA_PublicKey*>(state->kex_pub))
463 "Unknown key type received in server kex");
468 "Certificate key type did not match ciphersuite");
472 if(!state->server_kex->verify(peer_certs[0],
473 state->client_hello->random(),
474 state->server_hello->random()))
476 "Bad signature on server key exchange");
// CertificateRequest: remember that client authentication was requested.
481 client_check_state(type, state);
483 state->cert_req =
new Certificate_Req(contents);
484 state->do_client_auth =
true;
// ServerHelloDone: build the client flight. Messages constructed with
// `writer` are presumably written out immediately. The flight is:
//   - Certificate (if client auth was requested);
//   - ClientKeyExchange;
//   - CertificateVerify (if client auth was requested).
// Then derive the session keys and create the client Finished.
// NOTE(review): send_certs is never filled in the visible lines.
488 client_check_state(type, state);
490 state->server_hello_done =
new Server_Hello_Done(contents);
492 if(state->do_client_auth)
494 std::vector<X509_Certificate> send_certs;
496 std::vector<Certificate_Type> types =
497 state->cert_req->acceptable_types();
500 state->client_certs =
new Certificate(writer, send_certs,
505 new Client_Key_Exchange(rng, writer, state->hash,
506 state->kex_pub, state->version,
507 state->client_hello->version());
509 if(state->do_client_auth)
// NOTE(review): key_matching_cert stays 0 in the visible lines; confirm a
// real key is selected before Certificate_Verify uses it.
511 Private_Key* key_matching_cert = 0;
512 state->client_verify =
new Certificate_Verify(rng,
517 state->keys = SessionKeys(state->suite, state->version,
518 state->client_kex->pre_master_secret(),
519 state->client_hello->random(),
520 state->server_hello->random());
527 state->client_finished =
new Finished(writer, state->version,
CLIENT,
528 state->keys.master_secret(),
// Server ChangeCipherSpec: record that it was received.
533 client_check_state(type, state);
536 state->got_server_ccs =
true;
// Server Finished: verify it using the master secret and handshake hash.
540 client_check_state(type, state);
542 state->server_finished =
new Finished(contents);
544 if(!state->server_finished->verify(state->keys.master_secret(),
545 state->version, state->hash,
SERVER))
547 "Finished message didn't verify");
554 throw Unexpected_Message(
"Unknown handshake message received");
// do_handshake (fragment): allocates a fresh Handshake_State and creates the
// ClientHello, constructed with writer and the handshake hash. After the
// (not visible) handshake loop, it throws TLS_Exception(HANDSHAKE_FAILURE)
// if the connection is neither active nor has a handshake state.
// NOTE(review): any previous `state` is overwritten without a visible
// delete — confirm it is always null or freed beforehand.
560 void TLS_Client::do_handshake()
562 state =
new Handshake_State;
564 state->client_hello =
new Client_Hello(rng, writer, policy, state->hash);
570 if(!active && !state)
571 throw TLS_Exception(
HANDSHAKE_FAILURE,
"TLS_Client: Handshake failed (do_handshake)");
// NOTE(review): the lines below are bare declaration signatures (no bodies
// or terminating semicolons) of members referenced above. They appear to be
// an index appended during extraction and are not compilable C++ as written.
virtual Version_Code pref_version() const
size_t get_record(byte &msg_type, MemoryRegion< byte > &buffer)
size_t read(byte[], size_t)
virtual Version_Code min_version() const
virtual bool check_cert(const std::vector< X509_Certificate > &cert_chain) const =0
void add_input(const byte input[], size_t input_size)
void set_version(Version_Code)
void set_keys(const CipherSuite &, const SessionKeys &, Connection_Side)
void set_version(Version_Code version)
TLS_Client(std::tr1::function< size_t(byte[], size_t)> input_fn, std::tr1::function< void(const byte[], size_t)> output_fn, const TLS_Policy &policy, RandomNumberGenerator &rng)
void add_client_cert(const X509_Certificate &cert, Private_Key *cert_key)
RandomNumberGenerator * rng
std::vector< X509_Certificate > peer_cert_chain() const
size_t read(byte buf[], size_t buf_len)
void set_keys(const CipherSuite &suite, const SessionKeys &keys, Connection_Side side)
void write(const byte[], size_t)
void alert(Alert_Level, Alert_Type)
u32bit make_u32bit(byte i0, byte i1, byte i2, byte i3)
void write(const byte buf[], size_t buf_len)
void send(byte type, const byte input[], size_t length)