>Backouts_
Published on

TLS 시리즈 02: RSA로 키 교환해보기

Authors
  • avatar
    Name
    Backouts
    Twitter

앞의 글에서 RSA를 이해하기 위해
유클리드 알고리즘, 확장 유클리드 알고리즘,
그리고 모듈러 역원까지 하나씩 정리해봤습니다.

이제는 그 이론들을 실제 코드에서 사용해 볼 차례입니다.

이번 글에서는
제가 만들고 있던 TCP 기반 채팅 client, server 코드에 RSA를 직접 붙여보겠습니다.

이번 단계에서 목표

이번 글의 목표는 간단합니다.

  • 서버가 RSA 공개키, 개인키를 만든다.
  • 클라이언트는 서버의 공개키를 받는다.
  • 클라이언트는 랜덤한 숫자 하나를 RSA로 암호화해 서버에게 보낸다.
  • 서버는 그 값을 자기 개인키로 복호화한다.

이 랜덤한 숫자는 이후에
대칭키(AES)로 암호화하기 위한 세션 키 역할을 하게 됩니다.

아직 AES까지는 가지 않고
RSA가 실제로 키를 안전하게 전달하는 과정에 어떻게 쓰이는지를 살펴보겠습니다.


🔥 RSA 구현해보기

RSA의 기능을 하기 위한 모듈인 rsa.py를 생성했습니다.
구조를 간단히 정리하면 다음과 같습니다.

  • 소수 생성
    • Miller–Rabin 테스트 이용
  • 키 생성
    • 서로 다른 큰 소수 p, q
    • p, q => n, φ(n) => e, d
  • 암호화
    • m^e mod n
  • 복호화
    • c^d mod n

즉, 이 파일은 RSA가 돌아가기 위한 수학 엔진 역할입니다.

간단하게
pubkey, privkey = generate_keypair()
이 한 줄의 코드로 공개키(n, e), 개인키(n, d) 를 얻을 수 있습니다.

📌 rsa.py 코드 펼치기
import random

def extended_euclidean(a, b):
    # 확장 유클리드 알고리즘
    # a*x + b*y = gcd(a, b)를 만족하는 x, y를 구한다.
    # RSA에서는 이 x 값이 모듈러 역원이 된다.

    if b == 0:
        # 더 이상 나눌 수 없으면 a가 gcd
        return 1, 0, a

    # 다음 단계로 내려가서 계수들을 먼저 구한다
    x1, y1, g = extended_euclidean(b, a % b)

    # a = b*q + r  (r = a % b)
    # r = a - b*q 를 이용해 계수를 다시 정리
    x = y1
    y = x1 - (a // b) * y1

    return x, y, g

def modinv(a, n):
    # a의 모듈러 역원 (a^-1 mod n)을 구함
    # a * x ≡ 1 (mod n)

    x, _, g = extended_euclidean(a, n)

    # 서로소가 아니면 역원이 없음
    if g != 1:
        raise ValueError("역원이 존재하지 않습니다.")

    # 음수가 나올 수 있으므로 mod n 범위로 정리
    return x % n

# Miller–Rabin 확률적 소수 판별
def is_probable_prime(n, k=10):
    if n < 2:
        return False

    # 작은 소수 예외 처리
    small_primes = [2, 3, 5, 7, 11, 13, 17, 19, 23, 29]
    if n in small_primes:
        return True
    if any(n % p == 0 for p in small_primes):
        return False

    # n - 1 = 2^r * d (d는 홀수)
    r = 0
    d = n - 1
    while d % 2 == 0:
        r += 1
        d //= 2

    # k번의 랜덤 테스트
    for _ in range(k):
        a = random.randrange(2, n - 2)
        x = pow(a, d, n)

        if x == 1 or x == n - 1:
            continue

        for _ in range(r - 1):
            x = pow(x, 2, n)
            if x == n - 1:
                break
        else:
            return False

    return True

def generate_prime(bits):
    # 지정한 비트 길이의 소수 하나를 생성한다

    while True:
        # 비트 길이를 맞춘 랜덤 후보
        cand = random.getrandbits(bits)
        # MSB = 1
        # bits 만큼의 자릿수 보장
        cand |= (1 << (bits - 1))
        # 홀수로 제한
        cand |= 1

        # Miller–Rabin 통과 시 채택
        if is_probable_prime(cand):
            return cand

def generate_keypair(bits=1024, e=65537):
    # RSA 공개키 / 개인키 생성

    half = bits // 2

    while True:
        # 두 개의 큰 소수 생성
        p = generate_prime(half)
        q = generate_prime(half)
        if p == q:
            continue

        n = p * q
        phi = (p - 1) * (q - 1)

        # e와 phi는 서로소여야 함
        _, _, g = extended_euclidean(e, phi)
        if g != 1:
            continue

        d = modinv(e, phi)

        return (n, e), (n, d)

def rsa_encrypt_int(m, pubkey):
    # RSA 정수 암호화
    # c = m^e mod n
    n, e = pubkey

    if not (0 <= m < n):
        raise ValueError("m must satisfy 0 <= m < n")

    return pow(m, e, n)


def rsa_decrypt_int(c, privkey):
    # RSA 정수 복호화
    # m = c^d mod n
    n, d = privkey

    if not (0 <= c < n):
        raise ValueError("c must satisfy 0 <= c < n")

    return pow(c, d, n)

이 프로젝트에서는 서로 다른 두 소수 p, q를 생성하기 위해
Miller–Rabin 판별법
소수인지 아닌지를 판별하는 용도로 사용했습니다.

간단히 요약하면 다음과 같습니다.

  • n이 소수인지 판별하는 과정:
    1. n − 1을 2의 거듭제곱 형태로 분해합니다. (n − 1 = 2^r · d)
    2. 랜덤한 값 a를 선택해 a^d mod n을 계산합니다.
    3. 이후 제곱을 r-1번 반복하며 mod n 값이 n−1이 나오는지 확인합니다.
    4. 이 과정을 k번 반복해 모두 통과하면 소수일 가능성이 높다고 판단합니다.

이제 이걸 네트워크 코드에 연결해보겠습니다.


❓ 먼저 해결해야 할 문제

RSA 구현보다 먼저 해결해야 할 문제가 하나 있습니다.
소켓은 int를 보낼 수 없다.

RSA 함수는 모두 정수(int) 를 사용합니다.
하지만 소켓 통신은 bytes만 보낼 수 있습니다.

그래서 변환과정이 필요해집니다.
int를 bytes로, bytes를 int로 변환하는 코드를
프레이밍때 사용한 utils.py에 추가했습니다.

또한,
프레이밍을 조금 확장해서
메시지에 태그를 같이 보내기 위한 tagged함수역시 추가했습니다.

📌 utils.py 코드 펼치기
import struct

# 바이트 블록을 송신하는 함수
def send_block(sock, data):
    length = len(data)
    # 데이터의 길이를 4바이트로 패킹하여 전송함
    sock.sendall(struct.pack("!I", length) + data)

# 바이트 블록을 수신하는 함수
def recv_block(sock):
    # 4바이트를 먼저 수신하여 데이터 전체의 길이를 알아냄
    # 실전에서는 recv_exact 같은 함수로 4바이트 수신을 보장하는 것이 더 안전함
    header = sock.recv(4)
    if not header:
        # 헤더가 없을때 에러
        raise ConnectionError("헤더가 없습니다.")
    length = struct.unpack("!I", header)[0]

    buf = b''
    # 지정된 길이만큼의 데이터를 수신하여 버퍼에 저장
    while len(buf) < length:
        chunk = sock.recv(length - len(buf))
        if not chunk:
            raise ConnectionError("데이터 수신에 실패했습니다.")
        buf += chunk
    return buf

# ================================================
# 추가 된 부분
# ================================================

# int를 byte로
def int_to_bytes(x):
    length = max(1, (x.bit_length() + 7) // 8)
    return x.to_bytes(length, "big")

# byte를 int로
def bytes_to_int(b):
    return int.from_bytes(b, "big")

# 메시지를 태그와 함께 보내기 ("PUB", "PRI" 등)
def send_tagged(sock, tag, payload):
    # Enum일 때 값 꺼내기
    if hasattr(tag, "value"):
        tag = tag.value
    send_block(sock, tag.encode("ascii"))
    send_block(sock, payload)

def recv_tagged(sock):
    tag = recv_block(sock).decode("ascii")
    payload = recv_block(sock)
    return tag, payload

또한,
태그에 대한 오타를 줄이기 위해 protocol.pyEnum을 추가했습니다.

📌 protocol.py 코드 펼치기
from enum import Enum

class MsgType(Enum):
    # RSA 핸드셰이크 관련
    PUB_N = "PUBN"   # 서버 공개키 n
    PUB_E = "PUBE"   # 서버 공개키 e
    KEY_C = "KEYC"   # 클라이언트가 보낸 세션키(RSA 암호문)

    # 이후 채팅 기능 확장용
    CHAT  = "CHAT"   # 일반 채팅 메시지
    CLOSE = "CLOSE"  # 정상 종료
    ERROR = "ERROR"  # 오류 알림

🎉 채팅앱에 RSA 탑재하기

와 드디어 여기까지 왔습니다.
이 부분은 사실 매우 간단합니다.
위에서 생성한 RSA 방식을 사용해 통신만 하면 되니까요

💻 서버에서 공개키 생성과 복호화

서버의 역할은 단순합니다.

  • RSA 키 생성
  • 공개키 전송
  • 암호화된 세션 키를 받아 복호화
pubkey, privkey = generate_keypair()
n, e = pubkey

서버는 연결이 들어오면
이렇게 RSA 키를 먼저 생성합니다.

그 다음 공개키 (n, e)를 클라이언트에게 보냅니다.

send_tagged(conn, MsgType.PUB_N, int_to_bytes(n))
send_tagged(conn, MsgType.PUB_E, int_to_bytes(e))

클라이언트가 보내준 암호문은
개인키로 복호화합니다.

tag, payload = recv_tagged(conn)
if tag != MsgType.KEY_C.value:
    raise ValueError(f"unexpected tag: {tag}")

c = bytes_to_int(payload)
m = rsa_decrypt_int(c, privkey)
session_key = m.to_bytes(32, "big")

이 시점에서 서버는
클라이언트가 만든 세션 키를 정확히 복원하게 됩니다.

📌 server_rsa.py 코드 펼치기
import socket
import threading

from utils import recv_block, send_block, send_tagged, recv_tagged, int_to_bytes, bytes_to_int
from rsa import generate_keypair, rsa_decrypt_int
from protocol import MsgType

HOST = '127.0.0.1'
PORT = 30003
stop_event = threading.Event()

def recv_thread(conn):
    try:
        while not stop_event.is_set():
            data = recv_block(conn)
            if not data:
                break
            print(f"[Client] {data.decode()}")
            print("> ", end="", flush=True)
    except (ConnectionResetError, OSError) as e:
        print(f'\n[*] 수신 중 오류: {e}')
    finally:
        stop_event.set()
        try:
            conn.close()
        except:
            pass

def send_thread(conn):
    try:
        while not stop_event.is_set():
            try:
                print("> ", end="", flush=True)
                msg = input()
            except (KeyboardInterrupt, EOFError):
                print('\n[*] 서버 종료 중...')
                stop_event.set()
                break

            if not msg:
                continue

            try:
                send_block(conn, msg.encode())
            except (BrokenPipeError, OSError, ConnectionResetError) as e:
                print(f'\n[*] 메시지 전송 오류: {e}')
                stop_event.set()
                break
    finally:
        try:
            conn.close()
        except:
            pass

def rsa_handshake_server(conn):
    # 1. 서버가 RSA 키 생성
    pubkey, privkey = generate_keypair(bits=1024, e=65537)
    n, e = pubkey

    # 2. 공개키 전송: tag="PUB", payload = n_bytes or e_bytes
    # tagged를 2번 사용해서 n과 e 전송
    send_tagged(conn, MsgType.PUB_N, int_to_bytes(n))
    send_tagged(conn, MsgType.PUB_E, int_to_bytes(e))

    # 3. 클라이언트가 보낸 세션키 암호문 받기
    tag, payload = recv_tagged(conn)
    if tag != MsgType.KEY_C.value:
        raise ValueError(f"unexpected tag: {tag}")

    c = bytes_to_int(payload)
    m = rsa_decrypt_int(c, privkey)

    # 4. 세션키 길이 고정(32바이트)로 복원 (클라이언트도 32바이트)
    session_key = m.to_bytes(32, "big")
    return session_key

def main():
    sock = socket.socket(socket.AF_INET, socket.SOCK_STREAM)
    sock.setsockopt(socket.SOL_SOCKET, socket.SO_REUSEADDR, 1)

    try:
        sock.bind((HOST, PORT))
        sock.listen()
        print(f"[Server] {HOST}:{PORT} 대기...")

        conn, addr = sock.accept()
        print(f"[Server] {addr} 접속 허용")

        # 여기서 RSA 핸드셰이크 시작
        try:
            session_key = rsa_handshake_server(conn)
            print(f"[*] RSA 세션키 교환 완료 (앞 8바이트): {session_key[:8].hex()}")
        except Exception as e:
            print(f"[*] RSA 핸드셰이크 실패: {e}")
            stop_event.set()
            conn.close()
            return

        # 이후 채팅 스레드 시작
        recv_t = threading.Thread(target=recv_thread, args=(conn,))
        send_t = threading.Thread(target=send_thread, args=(conn,))
        recv_t.start()
        send_t.start()

        try:
            while not stop_event.is_set():
                recv_t.join(timeout=0.5)
                send_t.join(timeout=0.5)
                if not recv_t.is_alive() or not send_t.is_alive():
                    stop_event.set()
        except KeyboardInterrupt:
            print("\n[*] `Ctrl + C 감지` 서버 종료 중...")
            stop_event.set()

    finally:
        try:
            sock.close()
        except:
            pass
        print("\n[Server] 서버 종료")

if __name__ == "__main__":
    main()

🔎 클라이언트에서 공개키 수신과 암호화

클라이언트 쪽 흐름도 단순합니다.

먼저 서버로부터 공개키를 받습니다.

tag, n_bytes = recv_tagged(sock)
n = bytes_to_int(n_bytes)
tag, e_bytes = recv_tagged(sock)
e = bytes_to_int(e_bytes)
pubkey = (n, e) # 공개키

그 다음 랜덤한 숫자를 하나 만듭니다.

session_key = os.urandom(32)
m = bytes_to_int(session_key)

m 값은 반드시 0 ≤ m < n 범위 안에 있어야 합니다.

현재 RSA 키 길이가 1024비트이기 때문에 32바이트(256비트) 세션 키는 항상 n보다 작습니다.

이제 이 값을 RSA로 암호화합니다.

c = rsa_encrypt_int(m, pubkey)
send_tagged(sock, MsgType.KEY_C, int_to_bytes(c))

이렇게 해서
서버만 복호화할 수 있는
공개키로 암호화 한 세션키 값을 전송하게 됩니다.

📌 client_rsa.py 코드 펼치기
import socket
import threading
import os

from utils import recv_block, send_block, send_tagged, recv_tagged, int_to_bytes, bytes_to_int
from rsa import rsa_encrypt_int
from protocol import MsgType

HOST = '127.0.0.1'
PORT = 30003

stop_event = threading.Event()

def recv_thread(sock):
    try:
        while not stop_event.is_set():
            data = recv_block(sock)
            if not data:
                print("\n[*] 서버 연결이 끊어졌습니다.")
                break
            print(f"[Server] {data.decode()}")
            print("> ", end="", flush=True)
    except (ConnectionResetError, OSError) as e:
        print(f'\n[*] 수신 중 오류: {e}')
    finally:
        stop_event.set()
        try:
            sock.close()
        except:
            pass

def send_thread(sock):
    try:
        while not stop_event.is_set():
            try:
                print("> ", end="", flush=True)
                msg = input()
            except (KeyboardInterrupt, EOFError):
                print("\n[*] 종료.")
                stop_event.set()
                break

            if not msg:
                continue

            try:
                send_block(sock, msg.encode())
            except (BrokenPipeError, OSError, ConnectionResetError) as e:
                print(f'\n[*] 메시지 전송 오류: {e}')
                stop_event.set()
                break
    finally:
        try:
            sock.close()
        except:
            pass

def rsa_handshake_client(sock):
    # 1. 서버 공개키 받기
    tag, n_bytes = recv_tagged(sock)
    if tag != MsgType.PUB_N.value:
        raise ValueError(f"unexpected tag: {tag}")
    n = bytes_to_int(n_bytes)

    tag, e_bytes = recv_tagged(sock)
    if tag != MsgType.PUB_E.value:
        raise ValueError(f"unexpected tag: {tag}")
    e = bytes_to_int(e_bytes)

    pubkey = (n, e)

    # 2. 세션키 생성 (32바이트 고정)
    session_key = os.urandom(32)

    # 3. bytes => int
    m = bytes_to_int(session_key)

    # 4. RSA 암호화
    c = rsa_encrypt_int(m, pubkey)

    # 5. 암호문 전송 (정수=>바이트)
    send_tagged(sock, MsgType.KEY_C, int_to_bytes(c))

    return session_key

def main():
    sock = socket.socket(socket.AF_INET, socket.SOCK_STREAM)

    try:
        sock.connect((HOST, PORT))
        print(f'[Client] {HOST}:{PORT} 서버에 접속했습니다.')

        # 여기서 RSA 핸드셰이크 시작
        try:
            session_key = rsa_handshake_client(sock)
            print(f"[*] RSA 세션키 교환 완료 (앞 8바이트): {session_key[:8].hex()}")
        except Exception as e:
            print(f"[*] RSA 핸드셰이크 실패: {e}")
            stop_event.set()
            sock.close()
            return

        # 채팅 스레드 시작
        recv_t = threading.Thread(target=recv_thread, args=(sock,))
        send_t = threading.Thread(target=send_thread, args=(sock,))
        recv_t.start()
        send_t.start()

        try:
            while not stop_event.is_set():
                recv_t.join(0.5)
                send_t.join(0.5)
                if not recv_t.is_alive() or not send_t.is_alive():
                    stop_event.set()
        except KeyboardInterrupt:
            print("\n[*] `Ctrl+C 감지` 클라이언트 종료.")
            stop_event.set()

    finally:
        try:
            sock.close()
        except:
            pass
        print("\n[Client] 클라이언트 종료.")

if __name__ == '__main__':
    main()

server_rsa.py 실행
server_rsa.py 실행
client_rsa.py 실행
client_rsa.py 실행

❓ 왜 RSA로 메시지를 안 보내고 키만 보낼까?

여기서 한 가지 의문이 듭니다.

그냥 RSA로 채팅 내용을 다 암호화하면 안 되나?

가능은 하지만, 실제로는 그렇게 하지 않습니다.

RSA는 계산이 느리기 때문에
긴 데이터를 처리하기엔 부담이 크기 때문입니다.

그래서 실제로는 다음과 같이 사용됩니다.

  • RSA
    • 처음 한 번, 키 교환
  • AES
    • 그 이후 모든 데이터 통신

이번 글은 그 중
RSA가 담당하는 역할만 구현한 것입니다.


정리

이번 글에서는
직접 구현한 rsa.py를 사용해서
클라이언트와 서버가
RSA를 통해 공통의 비밀 키 하나를 교환하는 과정을 구현해봤습니다.

이제 드디어 다음 글에서는 이 세션 키를 이용해서
AES로 메시지를 암호화하고 통신하는 구조로
자연스럽게 확장할 수 있게 되었습니다.