안녕하세요 평범한 대학원생입니다.
연구를 하면서
다른 논문의 이론을 보고 그 내용을 짰는데
논문에서 말한 계산시간이 4초인 반면에
제가 짠 코드는 2시간이 넘게 걸리는 것을 발견했습니다.
그래서 제가 짠 코드가 굉장히 비효율적으로 짜여있는거 같은데
어느 부분이 비효율적인지 잘 모르겠습니다.
혹시 파이썬 잘 아시는 분들
조언해주시면 감사하겠습니다!!
본삭금했기 때문에 글 지우는 일 없을겁니다!
# -*- coding: utf-8 -*-
from __future__ import division
import numpy as np
import time
def tranprob(kappa,alpha,sigma,x,regime,l,sba,h,xup,xdown):
#스테이트의 transition probability 레짐 transition probability과는 다른거임
#regime은 그냥 parameter로 때움
#l=0,1,2 <=> t,m,b
#xup 과 xdown은 각각 state variable의 upperbound와 lowerbound
lamda = kappa * (alpha - x)
indcase = 1 * (x < xdown[regime,0]) + 2 * (x > xup[regime,0])#중간:0 작은경우:1 큰경우:2
pt = (indcase==0) * (sigma[regime,0] ** 2 + lamda[regime,0] * l[regime,0] * sba * np.sqrt(h) + lamda[regime,0] ** 2 * h) / (2 * (l[regime,0] * sba) ** 2)\
+ (indcase==1) * (sigma[regime,0] ** 2 - lamda[regime,0] * l[regime,0] * sba * np.sqrt(h) + lamda[regime,0] ** 2 * h) / (2 * (l[regime,0] * sba) ** 2)\
+ (indcase==2) * (1 + (sigma[regime,0] ** 2 + 3 * lamda[regime,0] * l[regime,0] * sba * np.sqrt(h) + lamda[regime,0] ** 2 * h) / (2 * (l[regime,0] * sba) ** 2))
pm = (indcase==0) * (sigma[regime,0] ** 2 - lamda[regime,0] * l[regime,0] * sba * np.sqrt(h) + lamda[regime,0] ** 2 * h) / (2 * (l[regime,0] * sba) ** 2)\
+ (indcase==1) * (1 -(sigma[regime,0] ** 2 + lamda[regime,0] ** 2*h) / (l[regime,0] * sba) ** 2)\
+ (indcase==2) * -(sigma[regime,0] ** 2 + 2 * lamda[regime,0] * l[regime,0] * sba * np.sqrt(h) + lamda[regime,0] ** 2 * h) / (l[regime,0] * sba) ** 2
pd = 1-pt-pm
return (indcase,pt,pm,pd)
--------------
여기까지가 제가 작성한 함수입니다. 이아래 부터가 본문인데
--------------
ttm = 1
S0 = 100
x0 = 0
kappa = np.array([[0.5],[1]])
sigma = np.array([[0.15],[0.25]])
alpha = np.array([[0.0275],[0.06875]])#simplex 방법을 사용하기 위해서는 +x0를 해야 됨
r = np.array([[0.06],[0.06]])
regime = 0
L = np.array([[0.5],[0.5]])
K = 100
sba =0.1 #sigma ba:state difference
n = 1000#time-step number
ind = 2#1 for call, 2 for put
h = ttm / n # time difference => t=0, h, 2h, ...,
p = np.array([[(L[0,0] * np.exp(-(L[0,0]+L[1,0])*h) + L[1,0]) / (L[0,0]+L[1,0]),0],[0,(L[1,0] * np.exp(-(L[0,0]+L[1,0])*h) + L[0,0])/(L[0,0]+L[1,0])]])
p[0,1] = 1 - p[0,0]
p[1,0] = 1 - p[1,1]
l = np.floor(2 * sigma / sba); l = l.astype(int)
maxL = np.max(l)
xup = alpha + (l * sba - np.sqrt((l * sba) ** 2 - sigma ** 2)) / (kappa * np.sqrt(h))
xdown = alpha - (l * sba - np.sqrt((l * sba) ** 2 - sigma ** 2)) / (kappa * np.sqrt(h))
dx = sba * np.sqrt(h)
grid = dx * np.arange(-2 * maxL * n,2 * maxL * n+1)
#만기 payoff
payoff = (ind == 1) * (S0 * np.exp(grid) - K) + (ind == 2) * (K - S0 * np.exp(grid))
payoff[payoff<0] = 0
valuebefore = np.concatenate(([payoff],[payoff]),axis=0)
tt=time.time()
--------
제가 궁금한 것은 이 for문입니다. 이 for 문이 돌려보면
거의 2시간 넘게 돌아야 결과 하나가 나오게끔 나옵니다.
이를 더 효율적으로 짤 수 있는 방법이 있을까요?
--------
for t in xrange(1,n+1):#
valueafter = np.zeros((2,4 * maxL * n - 4 * maxL * t +1))
for k in xrange(0,4 * maxL * (n-t) +1):
for i in xrange(0,2):#state regime
for j in xrange(0,2):#transition regime
(indcase,pt,pm,pd) = tranprob(kappa,alpha,sigma,grid[k + 2 * maxL * t],i,l,sba,h,xup,xdown)
if (indcase == 0):
valueafter[i,k] = valueafter[i,k] + np.exp(-r[i] * h) * p[i,j] * pt * valuebefore[j,2 * maxL + k + l[i,0] * 1] \
+ np.exp(-r[i] * h) * p[i,j] * pm * valuebefore[j,2 * maxL + k] \
+ np.exp(-r[i] * h) * p[i,j] * pd * valuebefore[j,2 * maxL + k - l[i,0] * 1 ]
elif (indcase == 1):
valueafter[i,k] = valueafter[i,k] + np.exp(-r[i] * h) * p[i,j] * pt * valuebefore[j,2 * maxL + (indcase == 1) * (k + l[i,0] * 2) ] \
+ np.exp(-r[i] * h) * p[i,j] * pm * valuebefore[j,2 * maxL + k + l[i,0] * 1 ] \
+ np.exp(-r[i] * h) * p[i,j] * pd * valuebefore[j,2 * maxL + k]
elif (indcase ==2):
valueafter[i,k] = valueafter[i,k] + np.exp(-r[i] * h) * p[i,j] * pt * valuebefore[j,2 * maxL + k] \
+ np.exp(-r[i] * h) * p[i,j] * pm * valuebefore[j,2 * maxL + k - l[i,0] * 1] \
+ np.exp(-r[i] * h) * p[i,j] * pd * valuebefore[j,2 * maxL + k - l[i,0] * 2]
valuebefore = valueafter
print(valueafter)
elapsed = time.time() - tt#
print(elapsed)#
------------------------------
위에서 굵은 색으로 한 tranprob 함수는 제가 직접 짠 함수이며, 위에 그 코드가 적혀있습니다.
전 이 코드가 왜 비효율적인지 이해가 잘 되지 않습니다.
지식을 내려주시기 바랍니다 ㅠㅠ